-
-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #64 from yaricom/fast-network-solver-serialization
Implementing serialization/deserialization of the model for `FastModularNetworkSolver`
- Loading branch information
Showing
6 changed files
with
308 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
package network | ||
|
||
import ( | ||
"encoding/json" | ||
"github.com/yaricom/goNEAT/v4/neat/math" | ||
"io" | ||
) | ||
|
||
// WriteModel is to write this FastModularNetworkSolver as a model to be used later. | ||
func (s *FastModularNetworkSolver) WriteModel(w io.Writer) error { | ||
dataHolder := newFastModularNetworkSolverData(s) | ||
enc := json.NewEncoder(w) | ||
return enc.Encode(dataHolder) | ||
} | ||
|
||
// ReadFMNSModel allows loading model encoding FastModularNetworkSolver. | ||
func ReadFMNSModel(reader io.Reader) (*FastModularNetworkSolver, error) { | ||
var data fastModularNetworkSolverData | ||
dec := json.NewDecoder(reader) | ||
if err := dec.Decode(&data); err != nil { | ||
return nil, err | ||
} | ||
activationFunctions := make([]math.NodeActivationType, len(data.ActivationFunctions)) | ||
for i, f := range data.ActivationFunctions { | ||
activationFunctions[i] = f.NodeActivation | ||
} | ||
var modules []*FastControlNode | ||
if len(data.Modules) > 0 { | ||
modules = make([]*FastControlNode, len(data.Modules)) | ||
for i, m := range data.Modules { | ||
modules[i] = &FastControlNode{ | ||
ActivationType: m.ActivationType.NodeActivation, | ||
InputIndexes: m.InputIndexes, | ||
OutputIndexes: m.OutputIndexes, | ||
} | ||
} | ||
} | ||
fmns := NewFastModularNetworkSolver( | ||
data.BiasNeuronCount, data.InputNeuronCount, data.OutputNeuronCount, | ||
data.TotalNeuronCount, activationFunctions, | ||
data.Connections, data.BiasList, modules, | ||
) | ||
fmns.Name = data.Name | ||
fmns.Id = data.Id | ||
return fmns, nil | ||
} | ||
|
||
type NodeActivator struct { | ||
NodeActivation math.NodeActivationType | ||
} | ||
|
||
type fastControlNodeData struct { | ||
ActivationType NodeActivator `json:"activation_type"` | ||
InputIndexes []int `json:"input_indexes"` | ||
OutputIndexes []int `json:"output_indexes"` | ||
} | ||
|
||
type fastModularNetworkSolverData struct { | ||
Id int `json:"id"` | ||
Name string `json:"name"` | ||
InputNeuronCount int `json:"input_neuron_count"` | ||
SensorNeuronCount int `json:"sensor_neuron_count"` | ||
OutputNeuronCount int `json:"output_neuron_count"` | ||
BiasNeuronCount int `json:"bias_neuron_count"` | ||
TotalNeuronCount int `json:"total_neuron_count"` | ||
ActivationFunctions []NodeActivator `json:"activation_functions"` | ||
BiasList []float64 `json:"bias_list"` | ||
Connections []*FastNetworkLink `json:"connections"` | ||
Modules []fastControlNodeData `json:"modules,omitempty"` | ||
} | ||
|
||
func newFastModularNetworkSolverData(n *FastModularNetworkSolver) *fastModularNetworkSolverData { | ||
data := &fastModularNetworkSolverData{ | ||
Id: n.Id, | ||
Name: n.Name, | ||
InputNeuronCount: n.inputNeuronCount, | ||
SensorNeuronCount: n.sensorNeuronCount, | ||
OutputNeuronCount: n.outputNeuronCount, | ||
BiasNeuronCount: n.biasNeuronCount, | ||
TotalNeuronCount: n.totalNeuronCount, | ||
ActivationFunctions: make([]NodeActivator, len(n.activationFunctions)), | ||
BiasList: n.biasList, | ||
Connections: n.connections, | ||
Modules: make([]fastControlNodeData, 0), | ||
} | ||
for i, v := range n.activationFunctions { | ||
data.ActivationFunctions[i] = NodeActivator{ | ||
NodeActivation: v, | ||
} | ||
} | ||
if n.modules != nil { | ||
for _, v := range n.modules { | ||
data.Modules = append(data.Modules, fastControlNodeData{ | ||
ActivationType: NodeActivator{NodeActivation: v.ActivationType}, | ||
InputIndexes: v.InputIndexes, | ||
OutputIndexes: v.OutputIndexes, | ||
}) | ||
} | ||
} | ||
return data | ||
} | ||
|
||
func (n *NodeActivator) MarshalText() ([]byte, error) { | ||
if activationName, err := math.NodeActivators.ActivationNameFromType(n.NodeActivation); err != nil { | ||
return nil, err | ||
} else { | ||
return []byte(activationName), nil | ||
} | ||
} | ||
|
||
func (n *NodeActivator) UnmarshalText(text []byte) (err error) { | ||
n.NodeActivation, err = math.NodeActivators.ActivationTypeFromName(string(text)) | ||
return err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
package network | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
"testing" | ||
) | ||
|
||
const jsonFMNStr = `{"id":123456,"name":"test network","input_neuron_count":2,"sensor_neuron_count":3,"output_neuron_count":2,"bias_neuron_count":1,"total_neuron_count":8,"activation_functions":["SigmoidSteepenedActivation","SigmoidSteepenedActivation","SigmoidSteepenedActivation","SigmoidSteepenedActivation","SigmoidSteepenedActivation","SigmoidSteepenedActivation","SigmoidSteepenedActivation","SigmoidSteepenedActivation"],"bias_list":[0,0,0,0,0,0,1,0],"connections":[{"source_index":1,"target_index":5,"weight":15,"signal":0},{"source_index":2,"target_index":5,"weight":10,"signal":0},{"source_index":2,"target_index":6,"weight":5,"signal":0},{"source_index":6,"target_index":7,"weight":17,"signal":0},{"source_index":5,"target_index":3,"weight":7,"signal":0},{"source_index":7,"target_index":3,"weight":4.5,"signal":0},{"source_index":7,"target_index":4,"weight":13,"signal":0}]}` | ||
const jsonFNMStrModule = `{"id":123456,"name":"test network","input_neuron_count":2,"sensor_neuron_count":3,"output_neuron_count":2,"bias_neuron_count":1,"total_neuron_count":8,"activation_functions":["SigmoidSteepenedActivation","SigmoidSteepenedActivation","SigmoidSteepenedActivation","LinearActivation","LinearActivation","LinearActivation","LinearActivation","NullActivation"],"bias_list":[0,0,0,0,0,10,1,0],"connections":[{"source_index":1,"target_index":5,"weight":15,"signal":0},{"source_index":2,"target_index":6,"weight":5,"signal":0},{"source_index":7,"target_index":3,"weight":4.5,"signal":0},{"source_index":7,"target_index":4,"weight":13,"signal":0}],"modules":[{"activation_type":"MultiplyModuleActivation","input_indexes":[5,6],"output_indexes":[7]}]}` | ||
|
||
const networkName = "test network" | ||
const networkId = 123456 | ||
|
||
func TestFastModularNetworkSolver_WriteModel_NoModule(t *testing.T) { | ||
net := buildNamedNetwork(networkName, networkId) | ||
|
||
fmm, err := net.FastNetworkSolver() | ||
require.NoError(t, err, "failed to create fast network solver") | ||
|
||
outBuf := bytes.NewBufferString("") | ||
err = fmm.(*FastModularNetworkSolver).WriteModel(outBuf) | ||
require.NoError(t, err, "failed to write model") | ||
|
||
println(outBuf.String()) | ||
|
||
var expected interface{} | ||
err = json.Unmarshal([]byte(jsonFMNStr), &expected) | ||
require.NoError(t, err, "failed to unmarshal expected json") | ||
var actual interface{} | ||
err = json.Unmarshal(outBuf.Bytes(), &actual) | ||
require.NoError(t, err, "failed to unmarshal actual json") | ||
|
||
assert.EqualValues(t, expected, actual, "model JSON does not match expected JSON") | ||
} | ||
|
||
func TestFastModularNetworkSolver_WriteModel_WithModule(t *testing.T) { | ||
net := buildNamedModularNetwork(networkName, networkId) | ||
|
||
fmm, err := net.FastNetworkSolver() | ||
require.NoError(t, err, "failed to create fast network solver") | ||
|
||
outBuf := bytes.NewBufferString("") | ||
err = fmm.(*FastModularNetworkSolver).WriteModel(outBuf) | ||
require.NoError(t, err, "failed to write model") | ||
|
||
println(outBuf.String()) | ||
|
||
var expected interface{} | ||
err = json.Unmarshal([]byte(jsonFNMStrModule), &expected) | ||
require.NoError(t, err, "failed to unmarshal expected json") | ||
var actual interface{} | ||
err = json.Unmarshal(outBuf.Bytes(), &actual) | ||
require.NoError(t, err, "failed to unmarshal actual json") | ||
|
||
assert.EqualValues(t, expected, actual, "model JSON does not match expected JSON") | ||
} | ||
|
||
func TestReadFMNSModel_NoModule(t *testing.T) { | ||
buf := bytes.NewBufferString(jsonFMNStr) | ||
|
||
fmm, err := ReadFMNSModel(buf) | ||
assert.NoError(t, err, "failed to read model") | ||
assert.NotNil(t, fmm, "failed to deserialize model") | ||
|
||
assert.Equal(t, fmm.Name, networkName, "wrong network name") | ||
assert.Equal(t, fmm.Id, networkId, "wrong network id") | ||
|
||
data := []float64{1.5, 2.0} // bias inherent | ||
err = fmm.LoadSensors(data) | ||
require.NoError(t, err, "failed to load sensors") | ||
|
||
// test that it operates as expected | ||
// | ||
net := buildNetwork() | ||
depth, err := net.MaxActivationDepth() | ||
require.NoError(t, err, "failed to calculate max depth") | ||
|
||
t.Logf("depth: %d\n", depth) | ||
logNetworkActivationPath(net, t) | ||
|
||
data = append(data, 1.0) // BIAS is third object | ||
err = net.LoadSensors(data) | ||
require.NoError(t, err, "failed to load sensors") | ||
res, err := net.ForwardSteps(depth) | ||
require.NoError(t, err, "error when trying to activate objective network") | ||
require.True(t, res, "failed to activate objective network") | ||
|
||
// do forward steps through the solver and test results | ||
// | ||
res, err = fmm.Relax(depth, .1) | ||
require.NoError(t, err, "error while do forward steps") | ||
require.True(t, res, "forward steps returned false") | ||
|
||
// check results by comparing activations of objective network and fast network solver | ||
// | ||
outputs := fmm.ReadOutputs() | ||
for i, out := range outputs { | ||
assert.Equal(t, net.Outputs[i].Activation, out, "wrong activation at: %d", i) | ||
} | ||
} | ||
|
||
func TestReadFMNSModel_ModularNetwork(t *testing.T) { | ||
buf := bytes.NewBufferString(jsonFNMStrModule) | ||
|
||
fmm, err := ReadFMNSModel(buf) | ||
assert.NoError(t, err, "failed to read model") | ||
assert.NotNil(t, fmm, "failed to deserialize model") | ||
|
||
assert.Equal(t, fmm.Name, networkName, "wrong network name") | ||
assert.Equal(t, fmm.Id, networkId, "wrong network id") | ||
|
||
data := []float64{1.0, 2.0} // bias inherent | ||
err = fmm.LoadSensors(data) | ||
require.NoError(t, err, "failed to load sensors") | ||
|
||
// test that it operates as expected | ||
// | ||
net := buildModularNetwork() | ||
depth, err := net.MaxActivationDepth() | ||
require.NoError(t, err, "failed to calculate max depth") | ||
|
||
t.Logf("depth: %d\n", depth) | ||
logNetworkActivationPath(net, t) | ||
|
||
// activate objective network | ||
// | ||
data = append(data, 1.0) // BIAS is third object | ||
err = net.LoadSensors(data) | ||
require.NoError(t, err, "failed to load sensors") | ||
res, err := net.ForwardSteps(depth) | ||
require.NoError(t, err, "error when trying to activate objective network") | ||
require.True(t, res, "failed to activate objective network") | ||
|
||
// do forward steps through the solver and test results | ||
// | ||
res, err = fmm.Relax(depth, 1) | ||
require.NoError(t, err, "error while do forward steps") | ||
require.True(t, res, "forward steps returned false") | ||
|
||
// check results by comparing activations of objective network and fast network solver | ||
// | ||
outputs := fmm.ReadOutputs() | ||
for i, out := range outputs { | ||
assert.Equal(t, net.Outputs[i].Activation, out, "wrong activation at: %d", i) | ||
} | ||
|
||
} | ||
func buildNamedNetwork(name string, id int) *Network { | ||
net := buildNetwork() | ||
net.Name = name | ||
net.Id = id | ||
return net | ||
} | ||
|
||
func buildNamedModularNetwork(name string, id int) *Network { | ||
net := buildModularNetwork() | ||
net.Name = name | ||
net.Id = id | ||
return net | ||
} |
Oops, something went wrong.