Skip to content

Commit

Permalink
Merge pull request #51 from JDAI-CV/fix_non_128_bug
Browse files Browse the repository at this point in the history
Fix non 128 bug
  • Loading branch information
daquexian authored Aug 22, 2019
2 parents a22bb11 + a71711e commit c472484
Show file tree
Hide file tree
Showing 308 changed files with 95 additions and 46 deletions.
6 changes: 6 additions & 0 deletions .daq_pm/configs/run_net_x86
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# It is a configuration file for [project_manager.vim](https://github.com/daquexian/project_manager.vim)
name binary-nn
type cpp
build_dir build_main_x86
cmake_options -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_BUILD_TYPE=Debug -DBNN_BUILD_TEST=OFF -DBNN_BUILD_BENCHMARK=OFF -DBNN_BUILD_MAIN_LIB=ON
target run
5 changes: 5 additions & 0 deletions dabnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ target_include_directories(dabnn
${CMAKE_CURRENT_BINARY_DIR}
${PROJECT_SOURCE_DIR}
)
target_include_directories(dabnn
SYSTEM
PUBLIC
${PROJECT_SOURCE_DIR}/third_party/eigen
)
target_link_libraries(dabnn
glog::glog
flatbuffers
Expand Down
91 changes: 63 additions & 28 deletions dabnn/layers/BinConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
stride_h(stride_h),
stride_w(stride_w) {
auto &mat_map = net.lock()->mat_map_;
const auto binaized_name = "binaized_for_" + output + "_cal";
if (mat_map.find(binaized_name) == mat_map.end()) {
auto &input_mat = *mat_map[input];
mat_map[binaized_name] =
std::make_shared<Mat>(input_mat.h, input_mat.w, input_mat.elem_c,
DataType::Bit, binaized_name);
if (method() == Method::DIRECT_CONV || method() == Method::BCONV_NAIVE) {
const auto binaized_name = "binaized_for_" + output + "_cal";
if (mat_map.find(binaized_name) == mat_map.end()) {
auto &input_mat = *mat_map[input];
mat_map[binaized_name] = std::make_shared<Mat>(
input_mat.h, input_mat.w, input_mat.elem_c, DataType::Bit,
binaized_name);
}
binarized_mat = mat(binaized_name);
}
binarized_mat = mat(binaized_name);

const auto pad_name = "pad_for_" + output + "_cal";
if (mat_map.find(pad_name) == mat_map.end()) {
Expand All @@ -43,18 +45,17 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
}
padded_mat = mat(pad_name);

const auto col_mat_name = "col_for_" + output + "_cal";
if (mat_map.find(col_mat_name) == mat_map.end()) {
const auto len =
output_mat->h * output_mat->w *
align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
mat_map[col_mat_name] =
std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
}
col_mat = mat(col_mat_name);

if (net.lock()->optimize && !direct_conv_compatible() &&
gemm_compatible()) {
if (method() == Method::BGEMM || method() == Method::BGEMM_NAIVE) {
const auto col_mat_name = "col_for_" + output + "_cal";
if (mat_map.find(col_mat_name) == mat_map.end()) {
const auto len =
output_mat->h * output_mat->w *
align_to(weight_mat->h * weight_mat->w * input_mat->elem_c,
128);
mat_map[col_mat_name] =
std::make_shared<Mat>(1, 1, len, bnn::DataType::Bit);
}
col_mat = mat(col_mat_name);
const auto trans_weight_mat_name = "trans_" + weight;
// transpose the weight for bgemm
const int m = weight_mat->n;
Expand All @@ -76,6 +77,24 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
}
}

BinConv::Method BinConv::method() const {
if (net_.lock()->optimize) {
if (direct_conv_compatible()) {
return Method::DIRECT_CONV;
} else if (gemm_compatible()) {
return Method::BGEMM;
} else {
return Method::BCONV_NAIVE;
}
} else {
if (weight_mat->c == 1) {
return Method::BCONV_NAIVE;
} else {
return Method::BGEMM_NAIVE;
}
}
}

bool BinConv::direct_conv_compatible() const {
#ifdef __aarch64__
if (weight_mat->h == 3 && weight_mat->w == 3 && input_mat->elem_c == 64 &&
Expand Down Expand Up @@ -121,12 +140,14 @@ bool BinConv::gemm_compatible() const {
}

void BinConv::forward_impl() const {
if (net_.lock()->optimize) {
if (direct_conv_compatible()) {
switch (method()) {
case Method::DIRECT_CONV: {
pack_mat(*input_mat, *binarized_mat);
pad(*binarized_mat, pad_h, pad_w, *padded_mat);
bconv_3x3(*padded_mat, *weight_mat, *output_mat, stride_h);
} else if (gemm_compatible()) {
break;
}
case Method::BGEMM: {
output_mat->fill<float>(0.f);

bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
Expand All @@ -139,17 +160,31 @@ void BinConv::forward_impl() const {
bgemm(m, n, k, static_cast<uint64_t *>(transposed_weight_mat->data),
m, static_cast<uint64_t *>(col_mat->data), k,
static_cast<float *>(output_mat->data), m);
} else {
break;
}
case Method::BGEMM_NAIVE: {
output_mat->fill<float>(0.f);

bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
pad_h, pad_w, stride_h, stride_w, 1, 1,
*col_mat);

const int m = weight_mat->n;
const int n = output_mat->h * output_mat->w;
const int k = weight_mat->total() / weight_mat->n;
bgemm_naive(m, n, k,
static_cast<uint64_t *>(transposed_weight_mat->data), m,
static_cast<uint64_t *>(col_mat->data), k,
static_cast<float *>(output_mat->data), m);
break;
}
case Method::BCONV_NAIVE: {
pack_mat(*input_mat, *binarized_mat);
baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1,
1, output_mat->c, *output_mat);
break;
}
} else {
pack_mat(*input_mat, *binarized_mat);
baseline_bconv(*binarized_mat, *weight_mat, weight_mat->h,
weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1,
output_mat->c, *output_mat);
}
}

Expand Down
7 changes: 7 additions & 0 deletions dabnn/layers/BinConv.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@ class BinConv : public Layer {
virtual std::string to_str() const;

private:
enum Method {
DIRECT_CONV = 0,
BGEMM,
BCONV_NAIVE,
BGEMM_NAIVE
};
bool direct_conv_compatible() const;
bool gemm_compatible() const;
Method method() const;
};
} // namespace bnn

Expand Down
20 changes: 8 additions & 12 deletions dabnn/mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ inline Mat::Mat(int _n, int _w, int _h, int _c, void *_data, DataType data_type,
", ", h, ", ", c);
}
elemsize = data_type == DataType::Float ? sizeof(float) : sizeof(uint64_t);
BNN_ASSERT(c > 0, c);
std::stringstream ss;
ss << "Not align, w: " << w << ", c: " << c << ", elemsize: " << elemsize;
BNN_ASSERT(!require_align || w * c == 1 || w * c * elemsize % 16 == 0,
Expand All @@ -283,7 +282,10 @@ inline Mat::Mat(int _n, int _w, int _h, int _c, void *_data, DataType data_type,
} else {
hstep = w * c;
}
BNN_ASSERT(hstep > 0, hstep);
if (data_num == 0) {
BNN_ASSERT(c > 0, c);
BNN_ASSERT(hstep > 0, hstep);
}

external_memory = true;
}
Expand Down Expand Up @@ -529,11 +531,6 @@ inline void Mat::create(int _w, int _h, int _c, DataType _data_type) {
h = _h;
c = _c;

if (w * c != 1 && w * c * elemsize % 16 != 0) {
LOG(FATAL) << "Not align, w: " << w << ", c: " << c
<< ", elemsize: " << elemsize;
throw std::invalid_argument("Not align!");
}
hstep = ncnn::alignSize(w * c * elemsize, 16) / elemsize;

if (total() > 0) {
Expand Down Expand Up @@ -563,11 +560,6 @@ inline void Mat::create(int _n, int _w, int _h, int _c, DataType _data_type,
if (h != 0) dims++;
if (c != 0) dims++;

if (require_align && w * c != 1 && w * c * elemsize % 16 != 0) {
LOG(FATAL) << "Not align, w: " << w << ", c: " << c
<< ", elemsize: " << elemsize;
throw std::invalid_argument("Not align!");
}
if (require_align) {
hstep = ncnn::alignSize(w * c * elemsize, 16) / elemsize;
} else {
Expand Down Expand Up @@ -612,24 +604,28 @@ inline size_t Mat::total() const {

template <typename T>
inline const T *Mat::point(int _n, int _h, int _w) const {
BNN_ASSERT(w * c == 1 || w * c * elemsize % 16 == 0, "");
BNN_ASSERT((_n == 0 && _h == 0 && _w == 0) || hstep > 0, hstep);
return (T *)data + _n * h * hstep + _h * hstep + _w * c;
}

template <typename T>
inline const T *Mat::point(int _h, int _w) const {
BNN_ASSERT(w * c == 1 || w * c * elemsize % 16 == 0, "");
BNN_ASSERT((_h == 0 && _w == 0) || hstep > 0, hstep);
return (T *)data + _h * hstep + _w * c;
}

template <typename T>
inline T *Mat::point(int _n, int _h, int _w) {
BNN_ASSERT(w * c == 1 || w * c * elemsize % 16 == 0, "");
BNN_ASSERT((_n == 0 && _h == 0 && _w == 0) || hstep > 0, hstep);
return (T *)data + _n * h * hstep + _h * hstep + _w * c;
}

template <typename T>
inline T *Mat::point(int _h, int _w) {
BNN_ASSERT(w * c == 1 || w * c * elemsize % 16 == 0, "");
BNN_ASSERT((_h == 0 && _w == 0) || hstep > 0, hstep);
return (T *)data + _h * hstep + _w * c;
}
Expand Down
12 changes: 6 additions & 6 deletions tests/net_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ TEST(net, bireal18imagenet_comparison) {
std::shared_ptr<bnn::Mat> blob1, blob2;
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet.dab");
net->optimize = false;
net->read("/data/local/tmp/model_imagenet.dab");
net->run(input);
blob1 = net->get_blob(blob_name);
}
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet.dab");
net->optimize = true;
net->read("/data/local/tmp/model_imagenet.dab");
net->run(input);
blob2 = net->get_blob(blob_name);
}
Expand All @@ -56,8 +56,8 @@ TEST(net, bireal18imagenet) {
const std::string blob_name = "188";
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet.dab");
net->optimize = true;
net->read("/data/local/tmp/model_imagenet.dab");
net->run(input);
const auto blob = net->get_blob(blob_name);
ASSERT_NEAR((*blob)[0], -0.9431, 1e-4);
Expand All @@ -74,15 +74,15 @@ TEST(net, bireal18imagenetstem_comparison) {
std::shared_ptr<bnn::Mat> blob1, blob2;
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->optimize = false;
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->run(input);
blob1 = net->get_blob(blob_name);
}
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->optimize = true;
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->run(input);
blob2 = net->get_blob(blob_name);
}
Expand All @@ -96,8 +96,8 @@ TEST(net, bireal18imagenetstem) {
const std::string blob_name = "216";
{
auto net = bnn::Net::create();
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->optimize = true;
net->read("/data/local/tmp/model_imagenet_stem.dab");
net->run(input);
const auto &blob = net->get_blob(blob_name);
ASSERT_NEAR((*blob)[0], 1.9842, 1e-4);
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit c472484

Please sign in to comment.