ExampleMains
Here are examples of code that can be put into main() in neural2d.cpp to drive the neural net.
The three conceptual "modes" should really be called "use cases," as noted below.
TRAINING mode is when the input samples are labelled with target output values, and we run the backprop algorithm after each sample to update the weights:
    // Here is an example of TRAINING mode -------------:
    myNet.eta = 0.1;
    myNet.reportEveryNth = 25;
    myNet.repeatInputSamples = true;
    myNet.doneErrorThreshold = 0.005;
    do {
        if (myNet.shuffleInputSamples) {
            myNet.sampleSet.shuffle();
        }
        for (auto &sample : myNet.sampleSet.samples) {
            myNet.feedForward(sample);
            myNet.backProp(sample);
            myNet.reportResults(sample);
            if (myNet.recentAverageError < myNet.doneErrorThreshold) {
                cout << "Solved! -- Saving weights..." << endl;
                myNet.saveWeights(weightsFilename);
                // Do whatever else needs to be done here
                exit(0);
            }
        }
    } while (myNet.repeatInputSamples);
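The stopping test in the loop above compares a running error average against doneErrorThreshold, so a single lucky sample should not end training early. The sketch below illustrates that idea in isolation. It is not neural2d's code: ErrorTracker, smoothingFactor, the starting value of 1.0, and samplesUntilDone() are all hypothetical names and choices made for illustration, and the smoothing formula is an assumption about how such an average might be maintained.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the net's error bookkeeping (not neural2d's class).
struct ErrorTracker {
    double recentAverageError = 1.0; // assumed pessimistic starting value
    double smoothingFactor = 125.0;  // assumed; larger means a slower-moving average

    // Fold one sample's error into the running average.
    void update(double sampleError) {
        assert(smoothingFactor >= 0.0);
        recentAverageError = (recentAverageError * smoothingFactor + sampleError)
                             / (smoothingFactor + 1.0);
    }
};

// Returns how many samples were consumed before the running average dropped
// below doneErrorThreshold, or -1 if it never did.
long samplesUntilDone(const std::vector<double> &sampleErrors,
                      double doneErrorThreshold,
                      ErrorTracker tracker) {
    long count = 0;
    for (double err : sampleErrors) {
        tracker.update(err);
        ++count;
        if (tracker.recentAverageError < doneErrorThreshold) {
            return count;
        }
    }
    return -1;
}
```

With smoothingFactor set to 1.0 and a starting average of 1.0, a run of zero-error samples moves the average through 0.5, 0.25, and 0.125, so a threshold of 0.2 is crossed on the third sample rather than the first.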
VALIDATE mode is when we have already trained a net and the input samples are labelled with target output values. We only report the accuracy of the outputs; we don't run backprop, so the weights are not updated.
    // Here is an example of VALIDATE mode -------------:
    myNet.reportEveryNth = 1;
    myNet.repeatInputSamples = false;
    myNet.loadWeights(weightsFilename); // Use weights from a trained net
    do {
        for (auto &sample : myNet.sampleSet.samples) {
            myNet.feedForward(sample);
            myNet.reportResults(sample);
        }
    } while (myNet.repeatInputSamples);
TRAINED mode is when we have already trained a net and the input samples are NOT labelled with target output values. We can only report the neural net's outputs; we can't report accuracy. The code is identical to the code for VALIDATE mode. The only difference is in the inputData.txt config file: if it contains target output values, they are used to report the net's accuracy; if it does not, only the net's outputs are reported.
    // Here is an example of TRAINED mode -------------:
    myNet.reportEveryNth = 1;
    myNet.repeatInputSamples = false;
    myNet.loadWeights(weightsFilename); // Use weights from a trained net
    do {
        for (auto &sample : myNet.sampleSet.samples) {
            myNet.feedForward(sample);
            myNet.reportResults(sample);
        }
    } while (myNet.repeatInputSamples);
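Because the VALIDATE and TRAINED loops are the same, the distinction has to live in the reporting step: outputs are always reported, and an error figure only when a sample carries targets. The following is a minimal sketch of that branching, not neural2d's actual reportResults(). The Sample struct, formatReport(), and the choice of sum squared error as the accuracy measure are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical sample record; neural2d's real Sample class may differ.
struct Sample {
    std::vector<double> targets; // left empty when the input file has no labels
};

// Builds one report line: the outputs always, an error figure only if the
// sample has target values to compare against.
std::string formatReport(const std::vector<double> &outputs, const Sample &sample) {
    std::ostringstream out;
    out << "outputs:";
    for (double v : outputs) {
        out << ' ' << v;
    }
    if (!sample.targets.empty()) {
        assert(sample.targets.size() == outputs.size());
        double sumSq = 0.0;
        for (std::size_t i = 0; i < outputs.size(); ++i) {
            double diff = sample.targets[i] - outputs[i];
            sumSq += diff * diff;
        }
        out << " | sum squared error: " << sumSq;
    }
    return out.str();
}
```

Calling formatReport() with an unlabelled sample corresponds to the TRAINED use case, and calling it with a labelled sample corresponds to VALIDATE; the caller's loop does not change.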