Skip to content

Commit

Permalink
add descriptor as args
Browse files Browse the repository at this point in the history
  • Loading branch information
abhilash1910 committed Nov 6, 2024
1 parent 962bb09 commit caee556
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class matmul_desc_t {
void *_d_scale_pointer = nullptr;
void *_absmax_d_pointer = nullptr;
void *_epilogue_aux_pointer = nullptr;
auto *_dgelu_epilogue = detail::sync_gelu_backward<::dnnl::eltwise_backward>(0.f, 0.f);
auto *_dgelu_epilogue = detail::sync_gelu_backward<::dnnl::eltwise_backward>(0.f, 0.f, new memory_desc_ext(), new memory_desc_ext());

friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc,
const void *alpha, const void *a,
Expand All @@ -219,11 +219,10 @@ namespace detail {
template <typename primitive_type, typename... args_type>
inline
typename primitive_type::primitive_desc sync_gelu_backward(
float alpha, float beta) {
float alpha, float beta, const memory_desc_ext &src_desc,
const memory_desc_ext &dest_desc) {

auto alg = ::dnnl::algorithm::eltwise_gelu_erf;
const memory_desc_ext &dst_desc = new memory_desc_ext();
const memory_desc_ext &src_desc = new memory_desc_ext();
return create_primitive_desc<primitive_type>(
::dnnl::prop_kind::backward, alg, src_desc.get_desc(),
dst_desc.get_desc(), alpha, beta);
Expand Down

0 comments on commit caee556

Please sign in to comment.