Skip to content

Commit

Permalink
add gelu
Browse files Browse the repository at this point in the history
  • Loading branch information
abhilash1910 committed Nov 29, 2024
1 parent 9f92c61 commit 1eb5880
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
5 changes: 4 additions & 1 deletion clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1624,7 +1624,10 @@ void MapNamesBlas::setExplicitNamespaceMap(
{"CUBLASLT_EPILOGUE_RELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::relu"},
{"CUBLASLT_EPILOGUE_BIAS",
{"CUBLASLT_EPILOGUE_RELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::gelu"}
{"CUBLASLT_EPILOGUE_GELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::bias"},
{"CUBLASLT_EPILOGUE_GELU_AUX_BIAS",
Expand Down
9 changes: 6 additions & 3 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ enum class pointer_mode_t {
alpha_device_vector_beta_zero,
alpha_device_vector_beta_host
};
enum class epilogue_t { nop = 1, relu, bias, gelu_aux_bias };
enum class epilogue_t { nop = 1, relu, bias, gelu, gelu_aux_bias };

class descriptor;
using descriptor_ptr = descriptor *;
Expand Down Expand Up @@ -783,6 +783,7 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
if (compute_desc->_epilogue != epilogue_t::nop &&
compute_desc->_epilogue != epilogue_t::relu &&
compute_desc->_epilogue != epilogue_t::bias &&
compute_desc->_epilogue != epilogue_t::gelu &&
compute_desc->_epilogue != epilogue_t::gelu_aux_bias) {
throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only "
"supports relu, gelu, gelu with bias epilogue currently.");
Expand Down Expand Up @@ -1032,14 +1033,16 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
::dnnl::post_ops matmul_ops;
if (compute_desc->_epilogue == epilogue_t::relu) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
} else if (compute_desc->_epilogue == epilogue_t::bias) {
} else if (compute_desc->_epilogue == epilogue_t::gelu) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
} else if (compute_desc->_epilogue == epilogue_t::bias) {
matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
} else if (compute_desc->_epilogue == epilogue_t::gelu_aux_bias) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c,
compute_desc->_epilogue_aux_ld, new_ldc, m, n,
sizeof(size_t) , q_ptr);
sizeof(size_t) , dpct::device_to_device, q_ptr);
}
matmul_attr.set_post_ops(matmul_ops);
}
Expand Down

0 comments on commit 1eb5880

Please sign in to comment.