-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.cpp
87 lines (80 loc) · 2.69 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <cmath>
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <armadillo>
#include <tclap/CmdLine.h>
#include "activations.hpp"
#include "nn.hpp"
using namespace std;
namespace {

/// Fills row[0] .. row[count - 1] with numbers extracted from `in`.
///
/// @param in    stream positioned at the first value to read
/// @param row   indexable container receiving the values
/// @param count number of values to extract
/// @return true when all `count` values were extracted; false as soon as one
///         is missing or is not a number (the remaining elements of `row`
///         are then left untouched).
template <typename Row>
bool ReadValues(std::istream& in, Row& row, std::size_t count) {
    for (std::size_t i = 0; i < count; ++i) {
        if (!(in >> row[i])) {
            return false;
        }
    }
    return true;
}

/// @return true when `line` contains nothing but whitespace.
bool IsBlank(const std::string& line) {
    return line.find_first_not_of(" \t\r\n") == std::string::npos;
}

/// Reports on stderr that input line `lineNo` (1-based) was rejected.
void WarnMalformed(std::size_t lineNo) {
    std::cerr << "Skipping malformed input line " << lineNo << '\n';
}

}  // namespace

/// Command-line entry point of the neural network tool.
///
/// Exactly one of --config / --model selects the file the network is built
/// from. Samples are then read from stdin, one per line, as whitespace
/// separated numbers, and processed in one of three modes:
///   - --test:    each line holds nn.XSize inputs followed by nn.YSize
///                targets; the accumulated squared prediction error is
///                reported as RMSE figures on stdout.
///   - --predict: each line holds nn.XSize inputs; the prediction is written
///                to stdout.
///   - default:   each line holds inputs followed by targets and is passed
///                to nn.Train().
/// Blank lines are ignored. Lines with missing or non-numeric values are
/// skipped with a warning on stderr instead of being processed with
/// partially filled vectors. With --save the model is written out afterwards.
///
/// @return 0 once the input has been consumed.
int main(int argc, char** argv) {
    TCLAP::CmdLine cmd("Simple fully connected neural network tool", ' ', "0.1");
    TCLAP::ValueArg<std::string> cmdConfig("c", "config", "Config file path", true, "", "path");
    TCLAP::ValueArg<std::string> cmdModel("m", "model", "Load model from file", true, "", "path");
    // The two sources are mutually exclusive: one, and only one, must be given.
    cmd.xorAdd(cmdConfig, cmdModel);
    TCLAP::ValueArg<std::string> cmdSave("s", "save", "Save model to path", false, "", "path");
    cmd.add(cmdSave);
    TCLAP::SwitchArg cmdTest("t","test","Test model", false);
    TCLAP::SwitchArg cmdPredict("p","predict","Make predictions", false);
    cmd.add(cmdTest);
    cmd.add(cmdPredict);
    cmd.parse(argc, argv);

    // isDump tells the NeuralNetwork constructor that fname came from --model
    // rather than from --config.
    std::string fname;
    bool isDump = false;
    if (cmdConfig.isSet()) {
        fname = cmdConfig.getValue();
    } else {
        fname = cmdModel.getValue();
        isDump = true;
    }
    NeuralNetwork nn(fname, isDump);

    std::string line;
    std::size_t lineNo = 0;  // 1-based number of the line being processed
    if (cmdTest.getValue()) { // Testing mode
        double err_sum = 0.0;
        std::size_t count = 0;  // number of samples that were actually evaluated
        while (std::getline(std::cin, line)) {
            ++lineNo;
            if (IsBlank(line)) {
                continue;
            }
            std::istringstream iss(line);
            RowVec x(nn.XSize);
            RowVec y(nn.YSize);
            if (!ReadValues(iss, x, nn.XSize) || !ReadValues(iss, y, nn.YSize)) {
                WarnMalformed(lineNo);
                continue;
            }
            RowVec delta_y = nn.Predict(x) - y;
            err_sum += arma::dot(delta_y, delta_y);
            ++count;
        }
        if (count == 0) {
            // Dividing by a zero sample count would print "nan" as the result.
            std::cerr << "No valid test samples were read; RMSE is undefined" << std::endl;
        } else {
            std::cout << "RMSE like: " << std::sqrt(err_sum / count) << std::endl;
            std::cout << "Per output RMSE: " << std::sqrt(err_sum / count / nn.YSize) << std::endl;
        }
    } else if (cmdPredict.getValue()) { // Make predictions
        while (std::getline(std::cin, line)) {
            ++lineNo;
            if (IsBlank(line)) {
                continue;
            }
            std::istringstream iss(line);
            RowVec x(nn.XSize);
            if (!ReadValues(iss, x, nn.XSize)) {
                // No output is produced for this line; the warning carries the
                // line number so outputs can still be matched to inputs.
                WarnMalformed(lineNo);
                continue;
            }
            std::cout << nn.Predict(x);
        }
    } else { // Training mode (default)
        while (std::getline(std::cin, line)) {
            ++lineNo;
            if (IsBlank(line)) {
                continue;
            }
            std::istringstream iss(line);
            RowVec x(nn.XSize);
            RowVec y(nn.YSize);
            if (!ReadValues(iss, x, nn.XSize) || !ReadValues(iss, y, nn.YSize)) {
                WarnMalformed(lineNo);
                continue;
            }
            nn.Train(x, y);
        }
    }

    if (cmdSave.isSet()) {
        // NOTE(review): the meaning of the second SaveModel argument is not
        // visible in this file; it is kept as in the original call.
        nn.SaveModel(cmdSave.getValue(), true);
    }
    return 0;
}