Skip to content

Commit

Permalink
Waiting for inputs reimplemented
Browse files Browse the repository at this point in the history
  • Loading branch information
hovsep committed Oct 3, 2024
1 parent b6d4d50 commit 730104e
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 18 deletions.
35 changes: 31 additions & 4 deletions component/activation_result.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package component

import (
"errors"
"fmt"
)

Expand Down Expand Up @@ -30,6 +31,12 @@ const (

// ActivationCodePanicked : component is activated, but panicked
ActivationCodePanicked

// ActivationCodeWaitingForInputs : component waits for specific inputs, but all input signals in current activation cycle may be cleared (default behaviour)

Check notice on line 35 in component/activation_result.go

View workflow job for this annotation

GitHub Actions / qodana

Comment of exported element starts with the incorrect name

Comment should have the following format 'ActivationCodeWaitingForInputsClear ...' (with an optional leading article)

Check notice on line 35 in component/activation_result.go

View workflow job for this annotation

GitHub Actions / qodana

Comment of exported element starts with the incorrect name

Comment should have the following format 'ActivationCodeWaitingForInputsClear ...' (with an optional leading article)
ActivationCodeWaitingForInputsClear

// ActivationCodeWaitingForInputsKeep : component waits for specific inputs, but wants to keep current input signals for the next cycle
ActivationCodeWaitingForInputsKeep
)

// NewActivationResult creates a new activation result for given component
Expand Down Expand Up @@ -60,13 +67,13 @@ func (ar *ActivationResult) Code() ActivationResultCode {
return ar.code
}

// HasError returns true when activation result has an error
func (ar *ActivationResult) HasError() bool {
// IsError returns true when activation result has an error
func (ar *ActivationResult) IsError() bool {
return ar.code == ActivationCodeReturnedError && ar.Error() != nil
}

// HasPanic returns true when activation result is derived from panic
func (ar *ActivationResult) HasPanic() bool {
// IsPanic returns true when activation result is derived from panic
func (ar *ActivationResult) IsPanic() bool {
return ar.code == ActivationCodePanicked && ar.Error() != nil
}

Expand Down Expand Up @@ -125,3 +132,23 @@ func (c *Component) newActivationResultPanicked(err error) *ActivationResult {
WithActivationCode(ActivationCodePanicked).
WithError(err)
}

func (c *Component) newActivationResultWaitingForInputs(err error) *ActivationResult {
activationCode := ActivationCodeWaitingForInputsClear
if errors.Is(err, errWaitingForInputsKeep) {
activationCode = ActivationCodeWaitingForInputsKeep
}
return NewActivationResult(c.Name()).
SetActivated(true).
WithActivationCode(activationCode).
WithError(err)
}

func IsWaitingForInput(activationResult *ActivationResult) bool {
return activationResult.Code() == ActivationCodeWaitingForInputsClear ||
activationResult.Code() == ActivationCodeWaitingForInputsKeep
}

func WantsToKeepInputs(activationResult *ActivationResult) bool {
return activationResult.Code() == ActivationCodeWaitingForInputsKeep
}
4 changes: 2 additions & 2 deletions component/activation_result_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (collection ActivationResultCollection) Add(activationResults ...*Activatio
// HasErrors tells whether the collection contains at least one activation result with error and respective code
func (collection ActivationResultCollection) HasErrors() bool {
for _, ar := range collection {
if ar.HasError() {
if ar.IsError() {
return true
}
}
Expand All @@ -29,7 +29,7 @@ func (collection ActivationResultCollection) HasErrors() bool {
// HasPanics tells whether the collection contains at least one activation result with panic and respective code
func (collection ActivationResultCollection) HasPanics() bool {
for _, ar := range collection {
if ar.HasPanic() {
if ar.IsPanic() {
return true
}
}
Expand Down
18 changes: 11 additions & 7 deletions component/component.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package component

import (
"errors"
"fmt"
"github.com/hovsep/fmesh/port"
)
Expand Down Expand Up @@ -89,10 +90,6 @@ func (c *Component) hasActivationFunction() bool {
// MaybeActivate tries to run the activation function if all required conditions are met
// @TODO: hide this method from user
func (c *Component) MaybeActivate() (activationResult *ActivationResult) {
defer func() {
c.Inputs().Clear()
}()

defer func() {
if r := recover(); r != nil {
activationResult = c.newActivationResultPanicked(fmt.Errorf("panicked with: %v", r))
Expand All @@ -102,7 +99,6 @@ func (c *Component) MaybeActivate() (activationResult *ActivationResult) {
if !c.hasActivationFunction() {
//Activation function is not set (maybe useful while the mesh is under development)
activationResult = c.newActivationResultNoFunction()

return
}

Expand All @@ -115,14 +111,17 @@ func (c *Component) MaybeActivate() (activationResult *ActivationResult) {
//Invoke the activation func
err := c.f(c.Inputs(), c.Outputs())

if errors.Is(err, errWaitingForInputs) {
activationResult = c.newActivationResultWaitingForInputs(err)
return
}

if err != nil {
activationResult = c.newActivationResultReturnedError(err)

return
}

activationResult = c.newActivationResultOK()

return
}

Expand All @@ -132,3 +131,8 @@ func (c *Component) FlushOutputs() {
out.Flush()
}
}

// ClearInputs clears all input ports
func (c *Component) ClearInputs() {
c.Inputs().Clear()
}
4 changes: 2 additions & 2 deletions component/component_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,10 @@ func TestComponent_MaybeActivate(t *testing.T) {
assert.Equal(t, tt.wantActivationResult.Activated(), gotActivationResult.Activated())
assert.Equal(t, tt.wantActivationResult.ComponentName(), gotActivationResult.ComponentName())
assert.Equal(t, tt.wantActivationResult.Code(), gotActivationResult.Code())
if tt.wantActivationResult.HasError() {
if tt.wantActivationResult.IsError() {
assert.EqualError(t, gotActivationResult.Error(), tt.wantActivationResult.Error().Error())
} else {
assert.False(t, gotActivationResult.HasError())
assert.False(t, gotActivationResult.IsError())
}

})
Expand Down
19 changes: 19 additions & 0 deletions component/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package component

import (
"errors"
"fmt"
)

var (
errWaitingForInputs = errors.New("component is waiting for some inputs")
errWaitingForInputsKeep = fmt.Errorf("%w: do not clear input ports", errWaitingForInputs)
)

// NewErrWaitForInputs returns respective error
func NewErrWaitForInputs(keepInputs bool) error {
if keepInputs {
return errWaitingForInputsKeep
}
return errWaitingForInputs
}
12 changes: 11 additions & 1 deletion fmesh.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,17 @@ func (fm *FMesh) drainComponents(cycle *cycle.Cycle) {
continue
}

if component.IsWaitingForInput(activationResult) {
if !component.WantsToKeepInputs(activationResult) {
c.ClearInputs()
}
// Components waiting for inputs are not flushed
continue
}

// Normally components are fully drained
c.FlushOutputs()
c.ClearInputs()
}
}

Expand All @@ -127,7 +137,7 @@ func (fm *FMesh) Run() (cycle.Collection, error) {
}

func (fm *FMesh) mustStop(cycleResult *cycle.Cycle, cycleNum int) (bool, error) {
if (fm.config.CyclesLimit > 0) && (cycleNum >= fm.config.CyclesLimit) {
if (fm.config.CyclesLimit > 0) && (cycleNum > fm.config.CyclesLimit) {
return true, ErrReachedMaxAllowedCycles
}

Expand Down
4 changes: 2 additions & 2 deletions fmesh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,10 +601,10 @@ func TestFMesh_Run(t *testing.T) {
assert.Equal(t, tt.want[i].ActivationResults()[componentName].ComponentName(), gotActivationResult.ComponentName())
assert.Equal(t, tt.want[i].ActivationResults()[componentName].Code(), gotActivationResult.Code())

if tt.want[i].ActivationResults()[componentName].HasError() {
if tt.want[i].ActivationResults()[componentName].IsError() {
assert.EqualError(t, tt.want[i].ActivationResults()[componentName].Error(), gotActivationResult.Error().Error())
} else {
assert.False(t, gotActivationResult.HasError())
assert.False(t, gotActivationResult.IsError())
}
}
}
Expand Down
95 changes: 95 additions & 0 deletions integration_tests/ports/waiting_for_inputs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package ports

import (
"github.com/hovsep/fmesh"
"github.com/hovsep/fmesh/component"
"github.com/hovsep/fmesh/cycle"
"github.com/hovsep/fmesh/port"
"github.com/hovsep/fmesh/signal"
"github.com/stretchr/testify/assert"
"testing"
)

func Test_WaitingForInputs(t *testing.T) {
tests := []struct {
name string
setupFM func() *fmesh.FMesh
setInputs func(fm *fmesh.FMesh)
assertions func(t *testing.T, fm *fmesh.FMesh, cycles cycle.Collection, err error)
}{
{
name: "waiting for longer chain",
setupFM: func() *fmesh.FMesh {
getDoubler := func(name string) *component.Component {
return component.New(name).
WithDescription("This component just doubles the input").
WithInputs("i1").
WithOutputs("o1").
WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error {
inputNum := inputs.ByName("i1").Signals().FirstPayload().(int)
outputs.ByName("o1").PutSignals(signal.New(inputNum * 2))
return nil
})
}

d1 := getDoubler("d1")
d2 := getDoubler("d2")
d3 := getDoubler("d3")
d4 := getDoubler("d4")
d5 := getDoubler("d5")

s := component.New("sum").
WithDescription("This component just sums 2 inputs").
WithInputs("i1", "i2").
WithOutputs("o1").
WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error {
if !inputs.ByNames("i1", "i2").AllHaveSignals() {
return component.NewErrWaitForInputs(true)
}

inputNum1 := inputs.ByName("i1").Signals().FirstPayload().(int)
inputNum2 := inputs.ByName("i2").Signals().FirstPayload().(int)
outputs.ByName("o1").PutSignals(signal.New(inputNum1 + inputNum2))
return nil
})

//This chain consist of 3 components: d1->d2->d3
d1.Outputs().ByName("o1").PipeTo(d2.Inputs().ByName("i1"))
d2.Outputs().ByName("o1").PipeTo(d3.Inputs().ByName("i1"))

//This chain has only 2: d4->d5
d4.Outputs().ByName("o1").PipeTo(d5.Inputs().ByName("i1"))

//Both chains go into summator
d3.Outputs().ByName("o1").PipeTo(s.Inputs().ByName("i1"))
d5.Outputs().ByName("o1").PipeTo(s.Inputs().ByName("i2"))

return fmesh.New("fm").
WithComponents(d1, d2, d3, d4, d5, s).
WithConfig(fmesh.Config{
ErrorHandlingStrategy: fmesh.StopOnFirstErrorOrPanic,
CyclesLimit: 5,
})

},
setInputs: func(fm *fmesh.FMesh) {
//Put 1 signal to each chain so they start in the same cycle
fm.Components().ByName("d1").Inputs().ByName("i1").PutSignals(signal.New(1))
fm.Components().ByName("d4").Inputs().ByName("i1").PutSignals(signal.New(2))
},
assertions: func(t *testing.T, fm *fmesh.FMesh, cycles cycle.Collection, err error) {
assert.NoError(t, err)
result := fm.Components().ByName("sum").Outputs().ByName("o1").Signals().FirstPayload().(int)
assert.Equal(t, 16, result)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fm := tt.setupFM()
tt.setInputs(fm)
cycles, err := fm.Run()
tt.assertions(t, fm, cycles, err)
})
}
}

0 comments on commit 730104e

Please sign in to comment.