diff --git a/component/activation_result.go b/component/activation_result.go index 281fccc..9a64eda 100644 --- a/component/activation_result.go +++ b/component/activation_result.go @@ -1,6 +1,7 @@ package component import ( + "errors" "fmt" ) @@ -30,6 +31,12 @@ const ( // ActivationCodePanicked : component is activated, but panicked ActivationCodePanicked + + // ActivationCodeWaitingForInputs : component waits for specific inputs, but all input signals in current activation cycle may be cleared (default behaviour) + ActivationCodeWaitingForInputsClear + + // ActivationCodeWaitingForInputsKeep : component waits for specific inputs, but wants to keep current input signals for the next cycle + ActivationCodeWaitingForInputsKeep ) // NewActivationResult creates a new activation result for given component @@ -60,13 +67,13 @@ func (ar *ActivationResult) Code() ActivationResultCode { return ar.code } -// HasError returns true when activation result has an error -func (ar *ActivationResult) HasError() bool { +// IsError returns true when activation result has an error +func (ar *ActivationResult) IsError() bool { return ar.code == ActivationCodeReturnedError && ar.Error() != nil } -// HasPanic returns true when activation result is derived from panic -func (ar *ActivationResult) HasPanic() bool { +// IsPanic returns true when activation result is derived from panic +func (ar *ActivationResult) IsPanic() bool { return ar.code == ActivationCodePanicked && ar.Error() != nil } @@ -125,3 +132,23 @@ func (c *Component) newActivationResultPanicked(err error) *ActivationResult { WithActivationCode(ActivationCodePanicked). WithError(err) } + +func (c *Component) newActivationResultWaitingForInputs(err error) *ActivationResult { + activationCode := ActivationCodeWaitingForInputsClear + if errors.Is(err, errWaitingForInputsKeep) { + activationCode = ActivationCodeWaitingForInputsKeep + } + return NewActivationResult(c.Name()). + SetActivated(true). + WithActivationCode(activationCode). + WithError(err) +} + +func IsWaitingForInput(activationResult *ActivationResult) bool { + return activationResult.Code() == ActivationCodeWaitingForInputsClear || + activationResult.Code() == ActivationCodeWaitingForInputsKeep +} + +func WantsToKeepInputs(activationResult *ActivationResult) bool { + return activationResult.Code() == ActivationCodeWaitingForInputsKeep +} diff --git a/component/activation_result_collection.go b/component/activation_result_collection.go index 4d7219d..257eff8 100644 --- a/component/activation_result_collection.go +++ b/component/activation_result_collection.go @@ -19,7 +19,7 @@ func (collection ActivationResultCollection) Add(activationResults ...*Activatio // HasErrors tells whether the collection contains at least one activation result with error and respective code func (collection ActivationResultCollection) HasErrors() bool { for _, ar := range collection { - if ar.HasError() { + if ar.IsError() { return true } } @@ -29,7 +29,7 @@ func (collection ActivationResultCollection) HasErrors() bool { // HasPanics tells whether the collection contains at least one activation result with panic and respective code func (collection ActivationResultCollection) HasPanics() bool { for _, ar := range collection { - if ar.HasPanic() { + if ar.IsPanic() { return true } } diff --git a/component/component.go b/component/component.go index 4d7014c..2161d70 100644 --- a/component/component.go +++ b/component/component.go @@ -1,6 +1,7 @@ package component import ( + "errors" "fmt" "github.com/hovsep/fmesh/port" ) @@ -89,10 +90,6 @@ func (c *Component) hasActivationFunction() bool { // MaybeActivate tries to run the activation function if all required conditions are met // @TODO: hide this method from user func (c *Component) MaybeActivate() (activationResult *ActivationResult) { - defer func() { - c.Inputs().Clear() - }() - defer func() { if r := recover(); r != nil { activationResult = c.newActivationResultPanicked(fmt.Errorf("panicked with: %v", r)) @@ -102,7 +99,6 @@ func (c *Component) MaybeActivate() (activationResult *ActivationResult) { if !c.hasActivationFunction() { //Activation function is not set (maybe useful while the mesh is under development) activationResult = c.newActivationResultNoFunction() - return } @@ -115,14 +111,17 @@ func (c *Component) MaybeActivate() (activationResult *ActivationResult) { //Invoke the activation func err := c.f(c.Inputs(), c.Outputs()) + if errors.Is(err, errWaitingForInputs) { + activationResult = c.newActivationResultWaitingForInputs(err) + return + } + if err != nil { activationResult = c.newActivationResultReturnedError(err) - return } activationResult = c.newActivationResultOK() - return } @@ -132,3 +131,8 @@ func (c *Component) FlushOutputs() { out.Flush() } } + +// ClearInputs clears all input ports +func (c *Component) ClearInputs() { + c.Inputs().Clear() +} diff --git a/component/component_test.go b/component/component_test.go index 44b8417..0c8da6a 100644 --- a/component/component_test.go +++ b/component/component_test.go @@ -501,10 +501,10 @@ func TestComponent_MaybeActivate(t *testing.T) { assert.Equal(t, tt.wantActivationResult.Activated(), gotActivationResult.Activated()) assert.Equal(t, tt.wantActivationResult.ComponentName(), gotActivationResult.ComponentName()) assert.Equal(t, tt.wantActivationResult.Code(), gotActivationResult.Code()) - if tt.wantActivationResult.HasError() { + if tt.wantActivationResult.IsError() { assert.EqualError(t, gotActivationResult.Error(), tt.wantActivationResult.Error().Error()) } else { - assert.False(t, gotActivationResult.HasError()) + assert.False(t, gotActivationResult.IsError()) } }) diff --git a/component/errors.go b/component/errors.go new file mode 100644 index 0000000..671df5f --- /dev/null +++ b/component/errors.go @@ -0,0 +1,19 @@ +package component + +import ( + "errors" + "fmt" +) + +var ( + errWaitingForInputs = errors.New("component is waiting for some inputs") + errWaitingForInputsKeep = fmt.Errorf("%w: do not clear input ports", errWaitingForInputs) +) + +// NewErrWaitForInputs returns respective error +func NewErrWaitForInputs(keepInputs bool) error { + if keepInputs { + return errWaitingForInputsKeep + } + return errWaitingForInputs +} diff --git a/fmesh.go b/fmesh.go index 7fecc08..7736097 100644 --- a/fmesh.go +++ b/fmesh.go @@ -106,7 +106,17 @@ func (fm *FMesh) drainComponents(cycle *cycle.Cycle) { continue } + if component.IsWaitingForInput(activationResult) { + if !component.WantsToKeepInputs(activationResult) { + c.ClearInputs() + } + // Components waiting for inputs are not flushed + continue + } + + // Normally components are fully drained c.FlushOutputs() + c.ClearInputs() } } @@ -127,7 +137,7 @@ func (fm *FMesh) Run() (cycle.Collection, error) { } func (fm *FMesh) mustStop(cycleResult *cycle.Cycle, cycleNum int) (bool, error) { - if (fm.config.CyclesLimit > 0) && (cycleNum >= fm.config.CyclesLimit) { + if (fm.config.CyclesLimit > 0) && (cycleNum > fm.config.CyclesLimit) { return true, ErrReachedMaxAllowedCycles } diff --git a/fmesh_test.go b/fmesh_test.go index e656a45..f773528 100644 --- a/fmesh_test.go +++ b/fmesh_test.go @@ -601,10 +601,10 @@ func TestFMesh_Run(t *testing.T) { assert.Equal(t, tt.want[i].ActivationResults()[componentName].ComponentName(), gotActivationResult.ComponentName()) assert.Equal(t, tt.want[i].ActivationResults()[componentName].Code(), gotActivationResult.Code()) - if tt.want[i].ActivationResults()[componentName].HasError() { + if tt.want[i].ActivationResults()[componentName].IsError() { assert.EqualError(t, tt.want[i].ActivationResults()[componentName].Error(), gotActivationResult.Error().Error()) } else { - assert.False(t, gotActivationResult.HasError()) + assert.False(t, gotActivationResult.IsError()) } } } diff --git a/integration_tests/ports/waiting_for_inputs_test.go b/integration_tests/ports/waiting_for_inputs_test.go new file mode 100644 index 0000000..9c790c2 --- /dev/null +++ b/integration_tests/ports/waiting_for_inputs_test.go @@ -0,0 +1,95 @@ +package ports + +import ( + "github.com/hovsep/fmesh" + "github.com/hovsep/fmesh/component" + "github.com/hovsep/fmesh/cycle" + "github.com/hovsep/fmesh/port" + "github.com/hovsep/fmesh/signal" + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_WaitingForInputs(t *testing.T) { + tests := []struct { + name string + setupFM func() *fmesh.FMesh + setInputs func(fm *fmesh.FMesh) + assertions func(t *testing.T, fm *fmesh.FMesh, cycles cycle.Collection, err error) + }{ + { + name: "waiting for longer chain", + setupFM: func() *fmesh.FMesh { + getDoubler := func(name string) *component.Component { + return component.New(name). + WithDescription("This component just doubles the input"). + WithInputs("i1"). + WithOutputs("o1"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + inputNum := inputs.ByName("i1").Signals().FirstPayload().(int) + outputs.ByName("o1").PutSignals(signal.New(inputNum * 2)) + return nil + }) + } + + d1 := getDoubler("d1") + d2 := getDoubler("d2") + d3 := getDoubler("d3") + d4 := getDoubler("d4") + d5 := getDoubler("d5") + + s := component.New("sum"). + WithDescription("This component just sums 2 inputs"). + WithInputs("i1", "i2"). + WithOutputs("o1"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + if !inputs.ByNames("i1", "i2").AllHaveSignals() { + return component.NewErrWaitForInputs(true) + } + + inputNum1 := inputs.ByName("i1").Signals().FirstPayload().(int) + inputNum2 := inputs.ByName("i2").Signals().FirstPayload().(int) + outputs.ByName("o1").PutSignals(signal.New(inputNum1 + inputNum2)) + return nil + }) + + //This chain consist of 3 components: d1->d2->d3 + d1.Outputs().ByName("o1").PipeTo(d2.Inputs().ByName("i1")) + d2.Outputs().ByName("o1").PipeTo(d3.Inputs().ByName("i1")) + + //This chain has only 2: d4->d5 + d4.Outputs().ByName("o1").PipeTo(d5.Inputs().ByName("i1")) + + //Both chains go into summator + d3.Outputs().ByName("o1").PipeTo(s.Inputs().ByName("i1")) + d5.Outputs().ByName("o1").PipeTo(s.Inputs().ByName("i2")) + + return fmesh.New("fm"). + WithComponents(d1, d2, d3, d4, d5, s). + WithConfig(fmesh.Config{ + ErrorHandlingStrategy: fmesh.StopOnFirstErrorOrPanic, + CyclesLimit: 5, + }) + + }, + setInputs: func(fm *fmesh.FMesh) { + //Put 1 signal to each chain so they start in the same cycle + fm.Components().ByName("d1").Inputs().ByName("i1").PutSignals(signal.New(1)) + fm.Components().ByName("d4").Inputs().ByName("i1").PutSignals(signal.New(2)) + }, + assertions: func(t *testing.T, fm *fmesh.FMesh, cycles cycle.Collection, err error) { + assert.NoError(t, err) + result := fm.Components().ByName("sum").Outputs().ByName("o1").Signals().FirstPayload().(int) + assert.Equal(t, 16, result) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fm := tt.setupFM() + tt.setInputs(fm) + cycles, err := fm.Run() + tt.assertions(t, fm, cycles, err) + }) + } +}