
Added caching of handles for libxsmm forward convolutions
taknevski committed Jan 4, 2017
1 parent f0b2832 commit 804e73d
Showing 1 changed file with 87 additions and 4 deletions.
91 changes: 87 additions & 4 deletions tensorflow/core/kernels/xsmm_conv2d.cc
@@ -47,16 +47,97 @@ static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
}
}


class libxsmm_dnn_conv_desc_wrap {
 public:
  const libxsmm_dnn_conv_desc d;

  explicit libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_)
      : d(d_) {}

  // Two descriptors are equal iff every field that determines code
  // generation matches.
  bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const {
    return d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W &&
           d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u &&
           d.v == w.d.v && d.pad_h_in == w.d.pad_h_in &&
           d.pad_w_in == w.d.pad_w_in;
  }
};


struct HashFunction {
  // Hashes a descriptor by serializing its fields to a string. The fields
  // are concatenated without separators, so distinct descriptors may
  // collide; correctness is still preserved by operator==.
  std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
    std::ostringstream fields;
    fields << w.d.N << w.d.C << w.d.H << w.d.W << w.d.K << w.d.R << w.d.S
           << w.d.u << w.d.v << w.d.pad_h_in << w.d.pad_w_in;
    return std::hash<std::string>()(fields.str());
  }
};
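
Note: because the fields above are serialized without separators, descriptors such as (N=1, C=12) and (N=11, C=2) produce the same string and therefore the same hash. A field-mixing hash avoids that; the sketch below is only an illustration of the alternative, not part of this commit, and ConvKey/ConvKeyHash are hypothetical stand-ins for the descriptor wrapper.

#include <cstddef>
#include <functional>
#include <initializer_list>

// Hypothetical stand-in for the integer fields of libxsmm_dnn_conv_desc.
struct ConvKey { int N, C, H, W, K, R, S, u, v, pad_h_in, pad_w_in; };

struct ConvKeyHash {
  std::size_t operator()(const ConvKey& d) const {
    std::size_t seed = 0;
    for (int f : {d.N, d.C, d.H, d.W, d.K, d.R, d.S,
                  d.u, d.v, d.pad_h_in, d.pad_w_in}) {
      // boost::hash_combine-style mixing keeps field boundaries distinct.
      seed ^= std::hash<int>()(f) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
    return seed;
  }
};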

class handles {
 public:
  // Returns a cached handle for the descriptor, creating and caching one
  // on first use.
  libxsmm_dnn_conv_handle* find(const libxsmm_dnn_conv_desc_wrap& w) {
    auto i = libxsmm_handles.find(w);
    if (i != libxsmm_handles.end()) return i->second;
    libxsmm_dnn_err_t status;
    libxsmm_dnn_conv_handle* libxsmm_handle =
        libxsmm_dnn_create_conv_handle_check(w.d, &status);
    chk_libxsmm_err(status, "Create handle");
    libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
    return libxsmm_handle;
  }

  // Cached handles live until the owning cache object is destroyed.
  ~handles() {
    for (auto i = libxsmm_handles.begin(); i != libxsmm_handles.end(); ++i)
      chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(i->second),
                      "Destroy handle");
  }

 private:
  std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_conv_handle*,
                     HashFunction>
      libxsmm_handles;
};

// Process-wide cache of forward-convolution handles, shared across calls.
static handles libxsmm_handles;
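
Note: the cache performs unsynchronized map lookups and inserts; if CallLibxsmmConvGeneric can run concurrently, as TensorFlow kernels generally may, the lookup and insert need to be serialized. A minimal sketch of a locked variant follows, with hypothetical names (LockedCache, FindOrCreate) and the same key/hash types assumed; it is an illustration, not the committed code.

#include <mutex>
#include <unordered_map>

template <typename Key, typename Value, typename Hash>
class LockedCache {
 public:
  // Returns the cached value for `key`, invoking `make(key)` on a miss.
  template <typename Factory>
  Value* FindOrCreate(const Key& key, Factory make) {
    std::lock_guard<std::mutex> lock(mu_);  // serialize lookup and insert
    auto it = map_.find(key);
    if (it != map_.end()) return it->second;
    Value* v = make(key);
    map_.insert(std::make_pair(key, v));
    return v;
  }

 private:
  std::mutex mu_;
  std::unordered_map<Key, Value*, Hash> map_;
};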

template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
const libxsmm_dnn_conv_desc& desc,
libxsmm_dnn_conv_kind kind, InputPtr input,
FilterPtr filter, OutputPtr output) {
  libxsmm_dnn_err_t status;
  libxsmm_dnn_conv_handle* libxsmm_handle;
  libxsmm_dnn_conv_desc_wrap w(desc);

  if (kind == LIBXSMM_DNN_CONV_KIND_FWD) {
    // Forward handles are expensive to generate, so reuse them from the
    // cache; the cache owns them and destroys them in its destructor.
    libxsmm_handle = libxsmm_handles.find(w);
  } else {
    libxsmm_handle = libxsmm_dnn_create_conv_handle_check(desc, &status);
    chk_libxsmm_err(status, "Create handle");
  }

status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
if (status == LIBXSMM_DNN_WARN_FALLBACK) {
chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
@@ -110,7 +191,9 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");

  // Forward handles are owned by the cache; destroy only non-forward ones.
  if (kind != LIBXSMM_DNN_CONV_KIND_FWD)
    chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
                    "Destroy handle");

return true; // Succeeded
