From caee556e4fe826ed1f300be28c931704fab12128 Mon Sep 17 00:00:00 2001 From: Abhilash Majumder Date: Tue, 5 Nov 2024 23:55:44 -0800 Subject: [PATCH] add descriptor as args --- clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp index 5bc3522fa0d3..cdfb8cc3f8d5 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp @@ -203,7 +203,7 @@ class matmul_desc_t { void *_d_scale_pointer = nullptr; void *_absmax_d_pointer = nullptr; void *_epilogue_aux_pointer = nullptr; - auto *_dgelu_epilogue = detail::sync_gelu_backward<::dnnl::eltwise_backward>(0.f, 0.f); + auto *_dgelu_epilogue = detail::sync_gelu_backward<::dnnl::eltwise_backward>(0.f, 0.f, new memory_desc_ext(), new memory_desc_ext()); friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc, const void *alpha, const void *a, @@ -219,11 +219,10 @@ namespace detail { template inline typename primitive_type::primitive_desc sync_gelu_backward( - float alpha, float beta) { + float alpha, float beta, const memory_desc_ext &src_desc, + const memory_desc_ext &dest_desc) { auto alg = ::dnnl::algorithm::eltwise_gelu_erf; - const memory_desc_ext &dst_desc = new memory_desc_ext(); - const memory_desc_ext &src_desc = new memory_desc_ext(); return create_primitive_desc( ::dnnl::prop_kind::backward, alg, src_desc.get_desc(), dst_desc.get_desc(), alpha, beta);