Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
abhilash1910 committed Nov 29, 2024
1 parent de1cfa3 commit 52f8c8b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
6 changes: 6 additions & 0 deletions clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,12 @@ void MapNamesBlas::setExplicitNamespaceMap(
{"CUBLASLT_EPILOGUE_RELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::relu"},
{"CUBLASLT_EPILOGUE_DGELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::dgelu_epilogue"},
{"CUBLASLT_EPILOGUE_bgradb",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::bgradb_epilogue"},
{"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::transform_desc_t::attribute::scale_type"},
Expand Down
32 changes: 32 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ class matmul_desc_t {
b_scale_pointer,
d_scale_pointer,
absmax_d_pointer,
dgelu_epilogue,
bgradb_epilogue,
unsupport
};

Expand Down Expand Up @@ -188,6 +190,8 @@ class matmul_desc_t {
CASE(epilogue_aux_ld)
CASE(epilogue_aux_pointer)
CASE(epilogue_aux_data_type)
CASE(dgelu_epilogue)
CASE(bgradb_epilogue)
default:
break;
}
Expand All @@ -210,6 +214,9 @@ class matmul_desc_t {
void *_absmax_d_pointer = nullptr;
void *_bias_pointer = nullptr;
void *_epilogue_aux_pointer = nullptr;
auto *_dgelu_epilogue = detail::sync_gelu_backward<::dnnl::eltwise_backward>(0.f, 0.f, new memory_desc_ext(), new memory_desc_ext());
auto *_bgradb_epilogue = detail::sync_gelu_backward<::dnnl::reduction>(0.f, 0.f, new memory_desc_ext(), new memory_desc_ext());


friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc,
const void *alpha, const void *a,
Expand All @@ -222,6 +229,31 @@ class matmul_desc_t {

namespace detail {
/// Sacling each row of matrix D with the corresponding element of vector alpha.

template <typename primitive_type, typename... args_type>
inline
typename primitive_type::primitive_desc sync_gelu_backward(
float alpha, float beta, const memory_desc_ext &src_desc,
const memory_desc_ext &dest_desc) {

auto alg = ::dnnl::algorithm::eltwise_gelu_erf;
return create_primitive_desc<primitive_type>(
::dnnl::prop_kind::backward, alg, src_desc.get_desc(),
dst_desc.get_desc(), alpha, beta);
}

template <typename primitive_type, typename... args_type>
inline
typename primitive_type::primitive_desc bias_backward(
float alpha, float beta, const memory_desc_ext &src_desc,
const memory_desc_ext &dest_desc) {

auto alg = ::dnnl::algorithm::reduction_sum;
return create_primitive_desc<primitive_type>(
::dnnl::prop_kind::backward, alg, src_desc.get_desc(),
dst_desc.get_desc(), alpha, beta);
}

template <class T, class Talpha>
sycl::event scale_d_with_vector_alpha_impl(::dpct::cs::queue_ptr q_ptr,
int rows, int cols, T *d,
Expand Down
4 changes: 4 additions & 0 deletions clang/test/dpct/cublaslt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ void foo3() {
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_ld;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_pointer;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_data_type;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::dgelu_epilogue;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::bgradb_epilogue;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::a_scale_pointer;
Expand All @@ -220,6 +222,8 @@ void foo3() {
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD;
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER;
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE;
d = CUBLASLT_EPILOGUE_DGELU;
d = CUBLASLT_EPILOGUE_BGRADB;
d = CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET;
d = CUBLASLT_MATMUL_DESC_FAST_ACCUM;
d = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
Expand Down

0 comments on commit 52f8c8b

Please sign in to comment.