彻底弄懂BP反向传播,15行代码,C++实现也简单,MNIST分类98.54 %精度
使用C++实现图像分类,10轮nmist识别率达到98.54 %
基于mnist数据集的BP算法,专业
大量的备注,为了你学懂
代码简单,核心代码15行
仅仅依赖OpenBLAS(为了速度),如果不依赖太慢了。可以学习下
权重初始化,凯明初始化,fan_in + fan_out
ReLU激活函数
SGD Momentum
学习率衰减
loss函数中log的上下溢控制
sigmoid的exp上溢控制
下载数据集/依赖项,并解压编译
bash ./download.sh这个环节干了两件事
第一下载mnist数据集
第二下载OpenBLAS依赖项,并且解压后编译
运行
make test 测试查看效果
make image 加载workspace/5.bmp文件识别
make train 训练模型
// 前向部分
auto x = choice_rows(train_norm_images, image_indexs, ibatch * batch_size, batch_size);
auto y = choice_rows(train_onehot_labels, image_indexs, ibatch * batch_size, batch_size);
auto hidden = gemm_mul(x, false , input_to_hidden, false ) + hidden_bias;
auto hidden_act = hidden.relu();
auto output = gemm_mul(hidden_act, false , hidden_to_output, false ) + output_bias;
auto probability = output.sigmoid();
float loss = compute_loss(probability, y);
// 反向部分
auto doutput = (probability - y) / batch_size;
auto doutput_bias = row_sum(doutput);
auto dhidden_to_output = gemm_mul(hidden_act, true , doutput, false );
auto dhidden_act = gemm_mul(doutput, false , hidden_to_output, true );
auto dhidden = delta_relu(dhidden_act, hidden);
auto dinput_to_hidden = gemm_mul(x, true , dhidden, false );
auto dhidden_bias = row_sum(dhidden);
optim.update_params(...);
下载OpenBLAS ,编译后放到lean目录下,确保Makefile中路径匹配正确
下载MNIST数据集 放到workspace目录下
执行make train
训练模型,然后make test
进行模型测试
cd workspace
Help:
./pro train 执行训练
./pro test 执行测试
./pro image 5.bmp 加载28x28的bmp图像文件进行预测
我们的博客:http://www.zifuture.com:8090/archives/bp
我们的B站:https://space.bilibili.com/1413433465/