Skip to content

Commit

Permalink
add bgradb sample
Browse files Browse the repository at this point in the history
  • Loading branch information
abhilash1910 committed Nov 7, 2024
1 parent caee556 commit a3b3b8c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
5 changes: 4 additions & 1 deletion clang/lib/DPCT/MapNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2028,7 +2028,10 @@ void MapNames::setExplicitNamespaceMap(
"attribute::epilogue_aux_pointer"},
{"CUBLASLT_EPILOGUE_DGELU",
getLibraryHelperNamespace() + "blas_gemm::experimental::matmul_desc_t::"
"attribute::dgelu_epilogue"},
"attribute::dgelu_epilogue"},
{"CUBLASLT_EPILOGUE_BGRADB",
getLibraryHelperNamespace() + "blas_gemm::experimental::matmul_desc_t::"
"attribute::bgradb_epilogue"},
{"CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET",
getLibraryHelperNamespace() +
"blas_gemm::experimental::matmul_desc_t::attribute::unsupport"},
Expand Down
15 changes: 15 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class matmul_desc_t {
epilogue_aux_ld,
epilogue_aux_pointer,
dgelu_epilogue,
bgradb_epilogue,
a_scale_pointer,
b_scale_pointer,
d_scale_pointer,
Expand Down Expand Up @@ -184,6 +185,7 @@ class matmul_desc_t {
CASE(epilogue_aux_ld)
CASE(epilogue_aux_pointer)
CASE(dgelu_epilogue)
CASE(bgradb_epilogue)
default:
break;
}
Expand All @@ -204,6 +206,7 @@ class matmul_desc_t {
void *_absmax_d_pointer = nullptr;
void *_epilogue_aux_pointer = nullptr;
auto *_dgelu_epilogue = detail::sync_gelu_backward<::dnnl::eltwise_backward>(0.f, 0.f, new memory_desc_ext(), new memory_desc_ext());
auto *_bgradb_epilogue = detail::sync_gelu_backward<::dnnl::reduction>(0.f, 0.f, new memory_desc_ext(), new memory_desc_ext());

friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc,
const void *alpha, const void *a,
Expand All @@ -228,6 +231,18 @@ typename primitive_type::primitive_desc sync_gelu_backward(
dst_desc.get_desc(), alpha, beta);
}

template <typename primitive_type, typename... args_type>
inline
typename primitive_type::primitive_desc bias_backward(
float alpha, float beta, const memory_desc_ext &src_desc,
const memory_desc_ext &dest_desc) {

auto alg = ::dnnl::algorithm::reduction_sum;
return create_primitive_desc<primitive_type>(
::dnnl::prop_kind::backward, alg, src_desc.get_desc(),
dst_desc.get_desc(), alpha, beta);
}

/// Sacling each row of matrix D with the corresponding element of vector alpha.
template <class T, class Talpha>
sycl::event scale_d_with_vector_alpha_impl(::dpct::cs::queue_ptr q_ptr,
Expand Down
2 changes: 2 additions & 0 deletions clang/test/dpct/cublaslt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ void foo3() {
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_ld;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_pointer;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::dgelu_epilogue;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::bgradb_epilogue;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::a_scale_pointer;
Expand All @@ -216,6 +217,7 @@ void foo3() {
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD;
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER;
d = CUBLASLT_EPILOGUE_DGELU;
d = CUBLASLT_EPILOGUE_BGRADB;
d = CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET;
d = CUBLASLT_MATMUL_DESC_FAST_ACCUM;
d = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
Expand Down

0 comments on commit a3b3b8c

Please sign in to comment.