diff --git a/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp b/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp index 461e660f5056..3d0316731fd4 100644 --- a/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp +++ b/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp @@ -1624,6 +1624,12 @@ void MapNamesBlas::setExplicitNamespaceMap( {"CUBLASLT_EPILOGUE_RELU", MapNames::getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::relu"}, + {"CUBLASLT_EPILOGUE_DGELU", + MapNames::getLibraryHelperNamespace() + + "blas_gemm::experimental::epilogue_t::dgelu_epilogue"}, + {"CUBLASLT_EPILOGUE_bgradb", + MapNames::getLibraryHelperNamespace() + + "blas_gemm::experimental::epilogue_t::bgradb_epilogue"}, {"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE", MapNames::getLibraryHelperNamespace() + "blas_gemm::experimental::transform_desc_t::attribute::scale_type"}, diff --git a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp index a9ed88104b15..f0bd111e3f9e 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp @@ -146,6 +146,8 @@ class matmul_desc_t { b_scale_pointer, d_scale_pointer, absmax_d_pointer, + dgelu_epilogue, + bgradb_epilogue, unsupport }; @@ -188,6 +190,8 @@ class matmul_desc_t { CASE(epilogue_aux_ld) CASE(epilogue_aux_pointer) CASE(epilogue_aux_data_type) + CASE(dgelu_epilogue) + CASE(bgradb_epilogue) default: break; } @@ -210,6 +214,9 @@ class matmul_desc_t { void *_absmax_d_pointer = nullptr; void *_bias_pointer = nullptr; void *_epilogue_aux_pointer = nullptr; + typename primitive_type::primitive_desc *_dgelu_epilogue = detail::sync_gelu_backward<::dnnl::eltwise_backward>(0.f, 0.f, new ::dnnl::memory_desc_ext(), ::dnnl::new memory_desc_ext()); + typename primitive_type::primitive_desc *_bgradb_epilogue = detail::sync_gelu_backward<::dnnl::reduction>(0.f, 0.f, new ::dnnl::memory_desc_ext(), new ::dnnl::memory_desc_ext()); + friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc, const void *alpha, const void *a, @@ -222,6 +229,31 @@ class matmul_desc_t { namespace detail { /// Sacling each row of matrix D with the corresponding element of vector alpha. + +template +inline +typename primitive_type::primitive_desc sync_gelu_backward( + float alpha, float beta, const::dnnl::memory_desc_ext &src_desc, + const ::dnnl::memory_desc_ext &dest_desc) { + + auto alg = ::dnnl::algorithm::eltwise_gelu_erf; + return create_primitive_desc( + ::dnnl::prop_kind::backward, alg, src_desc.get_desc(), + dst_desc.get_desc(), alpha, beta); +} + +template +inline +typename primitive_type::primitive_desc bias_backward( + float alpha, float beta, const dnnl::memory_desc_ext &src_desc, + const ::dnnl::memory_desc_ext &dest_desc) { + + auto alg = ::dnnl::algorithm::reduction_sum; + return create_primitive_desc( + ::dnnl::prop_kind::backward, alg, src_desc.get_desc(), + dst_desc.get_desc(), alpha, beta); +} + template sycl::event scale_d_with_vector_alpha_impl(::dpct::cs::queue_ptr q_ptr, int rows, int cols, T *d, diff --git a/clang/test/dpct/cublaslt.cu b/clang/test/dpct/cublaslt.cu index 30fb82c05b4d..9ee3b6179a3d 100644 --- a/clang/test/dpct/cublaslt.cu +++ b/clang/test/dpct/cublaslt.cu @@ -197,6 +197,8 @@ void foo3() { // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_ld; // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_pointer; // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_data_type; + // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::dgelu_epilogue; + // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::bgradb_epilogue; // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport; // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport; // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::a_scale_pointer; @@ -220,6 +222,8 @@ void foo3() { d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD; d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER; d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE; + d = CUBLASLT_EPILOGUE_DGELU; + d = CUBLASLT_EPILOGUE_BGRADB; d = CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET; d = CUBLASLT_MATMUL_DESC_FAST_ACCUM; d = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;