diff --git a/neat/network/network.go b/neat/network/network.go index e85ce95..99be624 100644 --- a/neat/network/network.go +++ b/neat/network/network.go @@ -327,7 +327,11 @@ func (n *Network) ForwardSteps(steps int) (res bool, err error) { } func (n *Network) RecursiveSteps() (bool, error) { - return false, errors.New("RecursiveSteps is not implemented") + netDepth, err := n.MaxActivationDepthFast(0) + if err != nil { + return false, err + } + return n.ForwardSteps(netDepth) } func (n *Network) Relax(_ int, _ float64) (bool, error) { diff --git a/neat/network/network_test.go b/neat/network/network_test.go index 8f5a53e..2a9d423 100644 --- a/neat/network/network_test.go +++ b/neat/network/network_test.go @@ -216,12 +216,32 @@ func TestNetwork_ForwardSteps(t *testing.T) { assert.NoError(t, err) assert.True(t, res) + expectedOuts := []float64{1.0, 1.0} + assert.EqualValues(t, expectedOuts, net.ReadOutputs()) + // test zero steps res, err = net.ForwardSteps(0) assert.EqualError(t, err, ErrZeroActivationStepsRequested.Error()) assert.False(t, res) } +func TestNetwork_RecursiveSteps(t *testing.T) { + net := buildNetwork() + + data := []float64{0.5, 0.0, 1.5} + err := net.LoadSensors(data) + require.NoError(t, err, "failed to load sensors") + + relaxed, err := net.RecursiveSteps() + assert.NoError(t, err) + assert.True(t, relaxed) + + logNetworkActivationPath(net, t) + + expectedOuts := []float64{1.0, 1.0} + assert.EqualValues(t, expectedOuts, net.ReadOutputs()) +} + func TestNetwork_ForwardSteps_disconnected(t *testing.T) { net := buildDisconnectedNetwork()