-
Notifications
You must be signed in to change notification settings - Fork 91
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SYCLomatic] Enable migration for CUBLASLT_EPILOGUE_DGELU & EPILOGUE_BGRADB #2449
Conversation
float alpha, float beta) { | ||
|
||
auto alg = ::dnnl::algorithm::eltwise_gelu_erf; | ||
const memory_desc_ext &dst_desc = new memory_desc_ext(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where to "delete" dst_desc
and src_desc
?
@@ -44,6 +44,19 @@ using matmul_desc_ptr = matmul_desc_t *; | |||
class transform_desc_t; | |||
using transform_desc_ptr = transform_desc_t *; | |||
|
|||
template <typename primitive_type, typename... args_type> | |||
inline | |||
typename primitive_type::primitive_desc dgelu_epilogue_( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function name looks like a type name, please refine it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK
const memory_desc_ext &dst_desc = new memory_desc_ext(); | ||
const memory_desc_ext &src_desc = new memory_desc_ext(); | ||
return create_primitive_desc<primitive_type>( | ||
::dnnl::prop_kind::backward, alg, src_desc.get_desc(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For src_desc
and dst_desc
, is it OK to pass to the create_primitive_desc
directly?
No need to fill data?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please refine the PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please submit E2E test PR first.
typename primitive_type::primitive_desc sync_gelu_backward( | ||
float alpha, float beta, const::dnnl::memory_desc_ext &src_desc, | ||
const ::dnnl::memory_desc_ext &dest_desc) { | ||
|
||
auto alg = ::dnnl::algorithm::eltwise_gelu_erf; | ||
return create_primitive_desc<primitive_type>( | ||
::dnnl::prop_kind::backward, alg, src_desc.get_desc(), | ||
dst_desc.get_desc(), alpha, beta); | ||
} | ||
|
||
template <typename primitive_type, typename... args_type> | ||
inline | ||
typename primitive_type::primitive_desc bias_backward( | ||
float alpha, float beta, const dnnl::memory_desc_ext &src_desc, | ||
const ::dnnl::memory_desc_ext &dest_desc) { | ||
|
||
auto alg = ::dnnl::algorithm::reduction_sum; | ||
return create_primitive_desc<primitive_type>( | ||
::dnnl::prop_kind::backward, alg, src_desc.get_desc(), | ||
dst_desc.get_desc(), alpha, beta); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For backward primitives, it needs a forward primitive object as a constructor argument. https://github.com/oneapi-src/oneDNN/blob/main/include/oneapi/dnnl/dnnl.hpp#L7476
Transfer to @zhiweij1 . Thanks |
@zhiweij1 @zhimingwang36
(WIP)