diff --git a/main.go b/main.go index 97ab9a6..4e9159c 100644 --- a/main.go +++ b/main.go @@ -66,7 +66,8 @@ func mnistTrain(net *Network) { inputs := make([]float64, net.inputs) for i := range inputs { - x, _ := strconv.ParseFloat(record[i], 64) + // skip the first column (index 0) containing the label + x, _ := strconv.ParseFloat(record[i+1], 64) inputs[i] = (x / 255.0 * 0.999) + 0.001 } @@ -99,10 +100,8 @@ func mnistPredict(net *Network) { } inputs := make([]float64, net.inputs) for i := range inputs { - if i == 0 { - inputs[i] = 1.0 - } - x, _ := strconv.ParseFloat(record[i], 64) + // skip the first column (index 0) which contains the label + x, _ := strconv.ParseFloat(record[i+1], 64) inputs[i] = (x / 255.0 * 0.999) + 0.001 } outputs := net.Predict(inputs)