Skip to content

Commit

Permalink
add epilogue bias mem descriptor and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
abhilash1910 committed Dec 2, 2024
1 parent fc018b0 commit 77356a0
Showing 1 changed file with 41 additions and 9 deletions.
50 changes: 41 additions & 9 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,8 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
const void *c, matrix_layout_ptr c_desc, void *d,
matrix_layout_ptr d_desc,
::dpct::cs::queue_ptr q_ptr) {
matrix_layout_ptr bias_desc = c_desc;
const void *bias = c;
const size_t m = compute_desc->_trans_a == oneapi::mkl::transpose::nontrans
? a_desc->_rows
: a_desc->_cols;
Expand All @@ -754,6 +756,7 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
const library_data_t c_type = c_desc->_type;
const library_data_t d_type = d_desc->_type;
const library_data_t scale_type = compute_desc->_scale_type;
const library_data_t bias_type = bias_desc->_type;

if (!q_ptr)
q_ptr = &::dpct::cs::get_default_queue();
Expand Down Expand Up @@ -820,13 +823,15 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
const void *new_a = a;
const void *new_b = b;
const void *new_c = c;
const void *new_bias = c;
void *new_d = d;
bool new_a_allocated = false;
bool new_b_allocated = false;
bool new_c_allocated = false;
bool new_bias_allocated = false;
bool new_d_allocated = false;
size_t new_lda = a_desc->_ld, new_ldb = b_desc->_ld, new_ldc = c_desc->_ld,
new_ldd = d_desc->_ld;
new_ldbias = c_desc->_ld, new_ldd = d_desc->_ld;
std::vector<sycl::event> transform_events;
if (a_desc->_order != order_t::col) {
new_lda = a_desc->_rows;
Expand Down Expand Up @@ -896,6 +901,22 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
std::vector<sycl::event>{});
transform_events.push_back(e);
}

if (bias_desc->_order != order_t::col) {
new_ldbias = bias_desc->_rows;
size_t size_of_element =
dpct::detail::library_data_size[static_cast<unsigned int>(
bias_desc->_type)] /
8;
new_bias =
::dpct::cs::malloc(size_of_element * bias_desc->_cols * new_ldbias, *q_ptr);
new_bias_allocated = true;
sycl::event e = detail::type_dispatch<detail::matrix_transform_impl>(
bias_desc->_type, q_ptr, bias_desc->_rows, bias_desc->_cols, bias_desc->_ld,
bias_desc->_order, bias, new_ldbias, order_t::col, const_cast<void *>(new_bias),
std::vector<sycl::event>{});
transform_events.push_back(e);
}

if (d_desc->_order != order_t::col) {
new_ldd = d_desc->_rows;
Expand All @@ -917,6 +938,7 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
::dnnl::memory::dims src_dims = {M, K};
::dnnl::memory::dims weights_dims = {K, N};
::dnnl::memory::dims bias_dims = {M, N};
::dnnl::memory::dims epilogue_bias_dims = {M, N};
::dnnl::memory::dims dst_dims = {M, N};

const ::dnnl::memory::dims src_strides =
Expand All @@ -929,6 +951,8 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
: ::dnnl::memory::dims{static_cast<long>(new_ldb), 1};
const ::dnnl::memory::dims bias_strides =
::dnnl::memory::dims{1, static_cast<long>(new_ldc)};
const ::dnnl::memory::dims epilogue_bias_strides =
::dnnl::memory::dims{1, static_cast<long>(new_ldbias)};
const ::dnnl::memory::dims dst_strides =
::dnnl::memory::dims{1, static_cast<long>(new_ldd)};

Expand All @@ -941,6 +965,9 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
auto bias_md = ::dnnl::memory::desc(
bias_dims, dpct::dnnl::memory_desc_ext::to_dnnl_data_type(c_type),
bias_strides);
auto epilogue_bias_md = ::dnnl::memory::desc(
epilogue_bias_dims, dpct::dnnl::memory_desc_ext::to_dnnl_data_type(bias_type),
epilogue_bias_strides);
auto dst_md = ::dnnl::memory::desc(
dst_dims, dpct::dnnl::memory_desc_ext::to_dnnl_data_type(d_type),
dst_strides);
Expand All @@ -951,6 +978,8 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
new ::dnnl::memory(weights_md, handle->get_engine(), DNNL_MEMORY_NONE);
auto *bias_mem =
new ::dnnl::memory(bias_md, handle->get_engine(), DNNL_MEMORY_NONE);
auto *epilogue_bias_mem =
new ::dnnl::memory(epilogue_bias_md, handle->get_engine(), DNNL_MEMORY_NONE);
auto *dst_mem =
new ::dnnl::memory(dst_md, handle->get_engine(), DNNL_MEMORY_NONE);

Expand All @@ -959,12 +988,14 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
detail::type_dispatch<detail::set_buffer_impl>(b_type, weights_mem, new_b);
if (!beta_is_zero)
detail::type_dispatch<detail::set_buffer_impl>(c_type, bias_mem, new_c);
detail::type_dispatch<detail::set_buffer_impl>(bias_type, epilogue_bias_mem, new_bias);
detail::type_dispatch<detail::set_buffer_impl>(d_type, dst_mem, new_d);
#else
src_mem->set_data_handle(const_cast<void *>(new_a));
weights_mem->set_data_handle(const_cast<void *>(new_b));
if (!beta_is_zero)
bias_mem->set_data_handle(const_cast<void *>(new_c));
epilogue_bias_mem->set_data_handle(const_cast<void *>(new_bias));
dst_mem->set_data_handle(new_d);
#endif

Expand All @@ -973,6 +1004,7 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
matmul_args.insert({DNNL_ARG_WEIGHTS, *weights_mem});
if (!beta_is_zero)
matmul_args.insert({DNNL_ARG_BIAS, *bias_mem});
if (compute_desc->_epilogue == epilogue_t::bias ||
    compute_desc->_epilogue == epilogue_t::gelu_aux_bias) {
  // The epilogue bias is applied with a binary_add post-op (appended
  // below), so it must be bound as that post-op's SRC_1 argument.
  // Re-using DNNL_ARG_BIAS here would silently collide with the
  // beta-scaled C bias inserted above (std::unordered_map::insert
  // drops the duplicate key) and would never reach the post-op.
  // Post-op index: binary_add is op 0 for epilogue_t::bias, op 1 for
  // gelu_aux_bias (gelu is appended first) -- must match the
  // append_* order below.
  const int binary_po_idx =
      compute_desc->_epilogue == epilogue_t::gelu_aux_bias ? 1 : 0;
  matmul_args.insert(
      {DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_po_idx) | DNNL_ARG_SRC_1,
       *epilogue_bias_mem});
}
matmul_args.insert({DNNL_ARG_DST, *dst_mem});
::dnnl::primitive_attr matmul_attr;
::dnnl::memory *scales_alpha = nullptr;
Expand Down Expand Up @@ -1035,19 +1067,19 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
if (compute_desc->_epilogue == epilogue_t::relu) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
} else if (compute_desc->_epilogue == epilogue_t::gelu) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
} else if (compute_desc->_epilogue == epilogue_t::gelu_aux) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c,
compute_desc->_epilogue_aux_ld, new_ldc, m, n,
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_bias,
compute_desc->_epilogue_aux_ld, new_ldbias, m, n,
sizeof(size_t) , dpct::device_to_device, queue); // NOTE(review): element size should be derived from the aux data type, not sizeof(size_t); also confirm `queue` (vs. *q_ptr) is in scope here
} else if (compute_desc->_epilogue == epilogue_t::bias) {
// Use the epilogue-bias descriptor (bias_type / new_ldbias) introduced by
// this refactor; bias_md describes the beta-scaled C operand instead.
matmul_ops.append_binary(::dnnl::algorithm::binary_add, epilogue_bias_md);
} else if (compute_desc->_epilogue == epilogue_t::gelu_aux_bias) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c,
compute_desc->_epilogue_aux_ld, new_ldc, m, n,
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
matmul_ops.append_binary(::dnnl::algorithm::binary_add, epilogue_bias_md);
dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_bias,
compute_desc->_epilogue_aux_ld, new_ldbias, m, n,
sizeof(size_t) , dpct::device_to_device, queue); // NOTE(review): element size should be derived from the aux data type, not sizeof(size_t); also confirm `queue` (vs. *q_ptr) is in scope here
}
matmul_attr.set_post_ops(matmul_ops);
Expand Down

0 comments on commit 77356a0

Please sign in to comment.