Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Enable migration for CUBLASLT_EPILOGUE_DGELU & EPILOGUE_BGRADB #2449

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,12 @@ void MapNamesBlas::setExplicitNamespaceMap(
{"CUBLASLT_EPILOGUE_RELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::relu"},
{"CUBLASLT_EPILOGUE_DGELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::dgelu_epilogue"},
{"CUBLASLT_EPILOGUE_bgradb",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::bgradb_epilogue"},
{"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::transform_desc_t::attribute::scale_type"},
Expand Down
32 changes: 32 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ class matmul_desc_t {
b_scale_pointer,
d_scale_pointer,
absmax_d_pointer,
dgelu_epilogue,
bgradb_epilogue,
unsupport
};

Expand Down Expand Up @@ -188,6 +190,8 @@ class matmul_desc_t {
CASE(epilogue_aux_ld)
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
CASE(epilogue_aux_pointer)
CASE(epilogue_aux_data_type)
CASE(dgelu_epilogue)
CASE(bgradb_epilogue)
default:
break;
}
Expand All @@ -210,6 +214,9 @@ class matmul_desc_t {
void *_absmax_d_pointer = nullptr;
void *_bias_pointer = nullptr;
void *_epilogue_aux_pointer = nullptr;
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
typename primitive_type::primitive_desc *_dgelu_epilogue = detail::sync_gelu_backward<::dnnl::eltwise_backward>(0.f, 0.f, new ::dnnl::memory_desc_ext(), ::dnnl::new memory_desc_ext());
typename primitive_type::primitive_desc *_bgradb_epilogue = detail::sync_gelu_backward<::dnnl::reduction>(0.f, 0.f, new ::dnnl::memory_desc_ext(), new ::dnnl::memory_desc_ext());


friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc,
const void *alpha, const void *a,
Expand All @@ -222,6 +229,31 @@ class matmul_desc_t {

namespace detail {
/// Sacling each row of matrix D with the corresponding element of vector alpha.

template <typename primitive_type, typename... args_type>
inline
typename primitive_type::primitive_desc sync_gelu_backward(
float alpha, float beta, const::dnnl::memory_desc_ext &src_desc,
const ::dnnl::memory_desc_ext &dest_desc) {

auto alg = ::dnnl::algorithm::eltwise_gelu_erf;
return create_primitive_desc<primitive_type>(
::dnnl::prop_kind::backward, alg, src_desc.get_desc(),
dst_desc.get_desc(), alpha, beta);
}

template <typename primitive_type, typename... args_type>
inline
typename primitive_type::primitive_desc bias_backward(
float alpha, float beta, const dnnl::memory_desc_ext &src_desc,
const ::dnnl::memory_desc_ext &dest_desc) {

auto alg = ::dnnl::algorithm::reduction_sum;
return create_primitive_desc<primitive_type>(
::dnnl::prop_kind::backward, alg, src_desc.get_desc(),
dst_desc.get_desc(), alpha, beta);
}
Comment on lines +235 to +255
Copy link
Contributor

@zhiweij1 zhiweij1 Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For backward primitives, it needs a forward primitive object as a constructor argument. https://github.com/oneapi-src/oneDNN/blob/main/include/oneapi/dnnl/dnnl.hpp#L7476


template <class T, class Talpha>
sycl::event scale_d_with_vector_alpha_impl(::dpct::cs::queue_ptr q_ptr,
int rows, int cols, T *d,
Expand Down
4 changes: 4 additions & 0 deletions clang/test/dpct/cublaslt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ void foo3() {
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_ld;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_pointer;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_data_type;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::dgelu_epilogue;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::bgradb_epilogue;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::a_scale_pointer;
Expand All @@ -220,6 +222,8 @@ void foo3() {
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD;
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER;
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE;
d = CUBLASLT_EPILOGUE_DGELU;
d = CUBLASLT_EPILOGUE_BGRADB;
d = CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET;
d = CUBLASLT_MATMUL_DESC_FAST_ACCUM;
d = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
Expand Down
Loading