From 77356a0658e56f813ed846c2cf91da42977ec4d4 Mon Sep 17 00:00:00 2001
From: Abhilash Majumder
Date: Mon, 2 Dec 2024 04:31:07 -0800
Subject: [PATCH] add epilogue bias mem descriptor and refactor

---
 .../dpct-rt/include/dpct/blas_gemm_utils.hpp | 57 ++++++++++++++++++++-----
 1 file changed, 47 insertions(+), 10 deletions(-)

diff --git a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
index 44ef1a79075f..e034c453df64 100644
--- a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
+++ b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
@@ -742,6 +742,8 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
                           const void *c, matrix_layout_ptr c_desc, void *d,
                           matrix_layout_ptr d_desc,
                           ::dpct::cs::queue_ptr q_ptr) {
+  matrix_layout_ptr bias_desc = c_desc;
+  const void *bias = c;
   const size_t m = compute_desc->_trans_a == oneapi::mkl::transpose::nontrans
                        ? a_desc->_rows
                        : a_desc->_cols;
@@ -754,6 +756,7 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
   const library_data_t c_type = c_desc->_type;
   const library_data_t d_type = d_desc->_type;
   const library_data_t scale_type = compute_desc->_scale_type;
+  const library_data_t bias_type = bias_desc->_type;
   if (!q_ptr)
     q_ptr = &::dpct::cs::get_default_queue();
 
@@ -820,13 +823,15 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
   const void *new_a = a;
   const void *new_b = b;
   const void *new_c = c;
+  const void *new_bias = c;
   void *new_d = d;
   bool new_a_allocated = false;
   bool new_b_allocated = false;
   bool new_c_allocated = false;
+  bool new_bias_allocated = false; // FIXME: new_bias is never visibly freed; release it where the other new_* buffers are released
   bool new_d_allocated = false;
   size_t new_lda = a_desc->_ld, new_ldb = b_desc->_ld, new_ldc = c_desc->_ld,
-         new_ldd = d_desc->_ld;
+         new_ldbias = c_desc->_ld, new_ldd = d_desc->_ld;
   std::vector<sycl::event> transform_events;
   if (a_desc->_order != order_t::col) {
     new_lda = a_desc->_rows;
@@ -896,6 +901,22 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
         std::vector<sycl::event>{});
     transform_events.push_back(e);
   }
+
+  if (bias_desc->_order != order_t::col) {
+    new_ldbias = bias_desc->_rows;
+    size_t size_of_element =
+        dpct::detail::library_data_size[static_cast<unsigned>(
+            bias_desc->_type)] /
+        8;
+    new_bias =
+        ::dpct::cs::malloc(size_of_element * bias_desc->_cols * new_ldbias, *q_ptr);
+    new_bias_allocated = true;
+    sycl::event e = detail::type_dispatch(
+        bias_desc->_type, q_ptr, bias_desc->_rows, bias_desc->_cols, bias_desc->_ld,
+        bias_desc->_order, bias, new_ldbias, order_t::col, const_cast<void *>(new_bias),
+        std::vector<sycl::event>{});
+    transform_events.push_back(e);
+  }
 
   if (d_desc->_order != order_t::col) {
     new_ldd = d_desc->_rows;
@@ -917,6 +938,7 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
   ::dnnl::memory::dims src_dims = {M, K};
   ::dnnl::memory::dims weights_dims = {K, N};
   ::dnnl::memory::dims bias_dims = {M, N};
+  ::dnnl::memory::dims epilogue_bias_dims = {M, N};
   ::dnnl::memory::dims dst_dims = {M, N};
 
   const ::dnnl::memory::dims src_strides =
@@ -929,6 +951,8 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
           : ::dnnl::memory::dims{static_cast<int64_t>(new_ldb), 1};
   const ::dnnl::memory::dims bias_strides =
       ::dnnl::memory::dims{1, static_cast<int64_t>(new_ldc)};
+  const ::dnnl::memory::dims epilogue_bias_strides =
+      ::dnnl::memory::dims{1, static_cast<int64_t>(new_ldbias)};
   const ::dnnl::memory::dims dst_strides =
       ::dnnl::memory::dims{1, static_cast<int64_t>(new_ldd)};
 
@@ -941,6 +965,9 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
   auto bias_md = ::dnnl::memory::desc(
       bias_dims, dpct::dnnl::memory_desc_ext::to_dnnl_data_type(c_type),
       bias_strides);
+  auto epilogue_bias_md = ::dnnl::memory::desc(
+      epilogue_bias_dims, dpct::dnnl::memory_desc_ext::to_dnnl_data_type(bias_type),
+      epilogue_bias_strides);
   auto dst_md = ::dnnl::memory::desc(
       dst_dims, dpct::dnnl::memory_desc_ext::to_dnnl_data_type(d_type),
       dst_strides);
@@ -951,6 +978,8 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
       new ::dnnl::memory(weights_md, handle->get_engine(), DNNL_MEMORY_NONE);
   auto *bias_mem =
       new ::dnnl::memory(bias_md, handle->get_engine(), DNNL_MEMORY_NONE);
+  auto *epilogue_bias_mem =
+      new ::dnnl::memory(epilogue_bias_md, handle->get_engine(), DNNL_MEMORY_NONE);
   auto *dst_mem =
       new ::dnnl::memory(dst_md, handle->get_engine(), DNNL_MEMORY_NONE);
 
@@ -959,12 +988,14 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
   detail::type_dispatch(b_type, weights_mem, new_b);
   if (!beta_is_zero)
     detail::type_dispatch(c_type, bias_mem, new_c);
+  detail::type_dispatch(bias_type, epilogue_bias_mem, new_bias);
   detail::type_dispatch(d_type, dst_mem, new_d);
 #else
   src_mem->set_data_handle(const_cast<void *>(new_a));
   weights_mem->set_data_handle(const_cast<void *>(new_b));
   if (!beta_is_zero)
     bias_mem->set_data_handle(const_cast<void *>(new_c));
+  epilogue_bias_mem->set_data_handle(const_cast<void *>(new_bias));
   dst_mem->set_data_handle(new_d);
 #endif
 
@@ -1035,19 +1066,25 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
   if (compute_desc->_epilogue == epilogue_t::relu) {
     matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
   } else if (compute_desc->_epilogue == epilogue_t::gelu) {
-    matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
+    matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
   } else if (compute_desc->_epilogue == epilogue_t::gelu_aux) {
-    matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
-    dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c,
-                                compute_desc->_epilogue_aux_ld, new_ldc, m, n,
+    matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
+    dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_bias,
+                                compute_desc->_epilogue_aux_ld, new_ldbias, m, n,
                                 sizeof(size_t) , dpct::device_to_device, queue);
   } else if (compute_desc->_epilogue == epilogue_t::bias) {
-    matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
+    matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(matmul_ops.len()) |
+                            DNNL_ARG_SRC_1,
+                        *epilogue_bias_mem});
+    matmul_ops.append_binary(::dnnl::algorithm::binary_add, epilogue_bias_md);
   } else if (compute_desc->_epilogue == epilogue_t::gelu_aux_bias) {
-    matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
-    matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
-    dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c,
-                                compute_desc->_epilogue_aux_ld, new_ldc, m, n,
+    matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
+    matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(matmul_ops.len()) |
+                            DNNL_ARG_SRC_1,
+                        *epilogue_bias_mem});
+    matmul_ops.append_binary(::dnnl::algorithm::binary_add, epilogue_bias_md);
+    dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_bias,
+                                compute_desc->_epilogue_aux_ld, new_ldbias, m, n,
                                 sizeof(size_t) , dpct::device_to_device, queue);
   }
   matmul_attr.set_post_ops(matmul_ops);