// Patch to tensorflow/core/kernels/xsmm_conv2d.cc: keeps libxsmm convolution
// handles for forward convolutions in a file-static cache instead of creating
// and destroying one on every call to CallLibxsmmConvGeneric.
//
// NOTE: in this copy of the patch the line breaks are collapsed and the
// template argument lists appear to be missing (`std::hash()`,
// `std::unordered_map::iterator`, `template static bool`); they must be
// restored before the patch can be applied or compiled.
//
// Added definitions:
//   libxsmm_dnn_conv_desc_wrap
//     Holds a const copy `d` of a libxsmm_dnn_conv_desc. operator== compares
//     only N, C, H, W, K, R, S, u, v, pad_h_in and pad_w_in.
//   HashFunction
//     Streams those same eleven fields into separate ostringstreams,
//     concatenates the resulting strings with no separator, and hashes the
//     concatenation with std::hash.
//   handles
//     find(w): returns the handle stored for `w`; on a miss it calls
//     libxsmm_dnn_create_conv_handle_check(w.d, &status), passes the status to
//     chk_libxsmm_err, inserts the new handle and returns it.
//     ~handles(): calls libxsmm_dnn_destroy_conv_handle on every stored handle.
//   libxsmm_handles
//     File-static instance of `handles` (the private map member of `handles`
//     has the same name).
//
// Changes inside CallLibxsmmConvGeneric (only two hunks are visible here):
//   - kind == LIBXSMM_DNN_CONV_KIND_FWD takes its handle from
//     libxsmm_handles.find(w); every other kind still creates a fresh handle.
//   - the final "Destroy handle" call is skipped for FWD, so cached handles
//     stay alive until ~handles() runs.
//
// NOTE(review): find() looks up and inserts into the map with no locking
//   visible in this patch - confirm it is never called concurrently.
// NOTE(review): as far as this hunk shows, the LIBXSMM_DNN_WARN_FALLBACK
//   branch still destroys libxsmm_handle regardless of `kind`; for FWD that
//   pointer is also stored in the cache, so a later find() or ~handles()
//   would reuse it - verify.
// NOTE(review): descriptor fields other than the eleven listed above take no
//   part in operator== or the hash - confirm that two descriptors differing
//   only in other fields may share a handle.
// NOTE(review): find() inserts the handle right after chk_libxsmm_err(status,
//   ...); the body of chk_libxsmm_err is not visible here, so whether a failed
//   create can end up cached is unconfirmed.
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc index c207d4f7a..8bb9afd56 100644 --- a/tensorflow/core/kernels/xsmm_conv2d.cc +++ b/tensorflow/core/kernels/xsmm_conv2d.cc @@ -47,6 +47,81 @@ static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) { } } + +class libxsmm_dnn_conv_desc_wrap{ + public: + const libxsmm_dnn_conv_desc d; + + libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc &d_) : d(d_){ + } + bool operator==(const libxsmm_dnn_conv_desc_wrap &w) const{ + return( d.N == w.d.N && + d.C == w.d.C && + d.H == w.d.H && + d.W == w.d.W && + d.K == w.d.K && + d.R == w.d.R && + d.S == w.d.S && + d.u == w.d.u && + d.v == w.d.v && + d.pad_h_in == w.d.pad_h_in && + d.pad_w_in == w.d.pad_w_in + ); + } +}; + + +struct HashFunction{ + std::size_t operator()(const libxsmm_dnn_conv_desc_wrap & w) const{ + std::ostringstream N,C,H,W,K,R,S,u,v,padh,padw; + + N << w.d.N; C << w.d.C; + H << w.d.H; W << w.d.W; + K << w.d.K; R << w.d.R; + S << w.d.S; u << w.d.u; + v << w.d.v; padh << w.d.pad_h_in; + padw << w.d.pad_w_in; + + + std::string out_ = N.str() + C.str()\ + + H.str() + W.str()\ + + K.str() + R.str()\ + + S.str() + u.str()\ + + v.str() + padh.str()\ + + padw.str(); + + return ( std::hash()(out_)); + } +}; + +class handles{ + public: + libxsmm_dnn_conv_handle* find( const libxsmm_dnn_conv_desc_wrap &w) { + std::unordered_map::iterator i = libxsmm_handles.find(w); + if (i == libxsmm_handles.end()){ + libxsmm_dnn_err_t status; + libxsmm_dnn_conv_handle* libxsmm_handle = libxsmm_dnn_create_conv_handle_check(w.d, &status); + chk_libxsmm_err(status, "Create handle"); + libxsmm_handles.insert(std::make_pair(w, libxsmm_handle)); + return libxsmm_handle; + } + else + return i->second; + } + ~handles(){ + std::unordered_map::iterator i; + for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++) + chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(i->second), + "Destroy handle"); + } + private: + + std::unordered_map 
libxsmm_handles; + +}; + +static handles libxsmm_handles; + template static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, @@ -54,9 +129,15 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, FilterPtr filter, OutputPtr output) { libxsmm_dnn_err_t status; libxsmm_dnn_conv_handle* libxsmm_handle; - libxsmm_handle = libxsmm_dnn_create_conv_handle_check(desc, &status); - chk_libxsmm_err(status, "Create handle"); - + libxsmm_dnn_conv_desc_wrap w(desc); + + if(kind == LIBXSMM_DNN_CONV_KIND_FWD) + libxsmm_handle = libxsmm_handles.find(w); + else{ + libxsmm_handle = libxsmm_dnn_create_conv_handle_check(desc, &status); + chk_libxsmm_err(status, "Create handle"); + } + status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind); if (status == LIBXSMM_DNN_WARN_FALLBACK) { chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle), @@ -110,7 +191,9 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input"); chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output"); chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter"); - chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle), + + if(kind != LIBXSMM_DNN_CONV_KIND_FWD) + chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle), "Destroy handle"); return true; // Succeeded