From d943972cd0b3f7ae3faf1f0b8e93691497c62745 Mon Sep 17 00:00:00 2001 From: intwanghao Date: Mon, 24 Jun 2024 09:33:32 +0800 Subject: [PATCH] [SYCLomatic] Fix issues during migration of wmma functions and types (#2035) Signed-off-by: intwanghao --- clang/lib/DPCT/APINamesTemplateType.inc | 80 +++++++++---------- clang/lib/DPCT/APINamesWmma.inc | 13 +-- clang/runtime/dpct-rt/include/dpct/math.hpp | 68 ++++++++++++++++ .../dpct/enable-all-experimental-features.cu | 22 ++--- clang/test/dpct/wmma.cu | 22 ++--- clang/test/dpct/wmma_using_nvcuda.cu | 20 ++--- 6 files changed, 148 insertions(+), 77 deletions(-) diff --git a/clang/lib/DPCT/APINamesTemplateType.inc b/clang/lib/DPCT/APINamesTemplateType.inc index 00f9e82119f9..064714cdb53d 100644 --- a/clang/lib/DPCT/APINamesTemplateType.inc +++ b/clang/lib/DPCT/APINamesTemplateType.inc @@ -274,59 +274,59 @@ TYPE_REWRITE_ENTRY("thrust::random::uniform_int_distribution", TYPE_FACTORY(STR("oneapi::dpl::uniform_int_distribution"), TEMPLATE_ARG(0))) TYPE_REWRITE_ENTRY("nvcuda::wmma::matrix_a", - TYPE_CONDITIONAL_FACTORY( - checkEnableJointMatrixForType(), - TYPE_FACTORY(STR(MapNames::getClNamespace() + - "ext::oneapi::experimental::matrix::use::a")), - WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR))) + TYPE_CONDITIONAL_FACTORY( + checkEnableJointMatrixForType(), + TYPE_FACTORY(STR(MapNames::getDpctNamespace() + + "experimental::matrix::a")), + WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, + TYPESTR))) TYPE_REWRITE_ENTRY("nvcuda::wmma::matrix_b", - TYPE_CONDITIONAL_FACTORY( - checkEnableJointMatrixForType(), - TYPE_FACTORY(STR(MapNames::getClNamespace() + - "ext::oneapi::experimental::matrix::use::b")), - WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR))) + TYPE_CONDITIONAL_FACTORY( + checkEnableJointMatrixForType(), + TYPE_FACTORY(STR(MapNames::getDpctNamespace() + + "experimental::matrix::b")), + WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, + TYPESTR))) TYPE_REWRITE_ENTRY("nvcuda::wmma::row_major", - TYPE_CONDITIONAL_FACTORY( - checkEnableJointMatrixForType(), - TYPE_FACTORY(STR(MapNames::getClNamespace() + - "ext::oneapi::experimental::matrix::layout::row_major")), - WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR))) + TYPE_CONDITIONAL_FACTORY( + checkEnableJointMatrixForType(), + TYPE_FACTORY(STR(MapNames::getDpctNamespace() + + "experimental::matrix::row_major")), + WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, + TYPESTR))) TYPE_REWRITE_ENTRY("nvcuda::wmma::col_major", - TYPE_CONDITIONAL_FACTORY( - checkEnableJointMatrixForType(), - TYPE_FACTORY(STR(MapNames::getClNamespace() + - "ext::oneapi::experimental::matrix::layout::col_major")), - WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR))) + TYPE_CONDITIONAL_FACTORY( + checkEnableJointMatrixForType(), + TYPE_FACTORY(STR(MapNames::getDpctNamespace() + + "experimental::matrix::col_major")), + WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, + TYPESTR))) TYPE_REWRITE_ENTRY("nvcuda::wmma::accumulator", - TYPE_CONDITIONAL_FACTORY( - checkEnableJointMatrixForType(), - TYPE_FACTORY(STR(MapNames::getClNamespace() + - "ext::oneapi::experimental::matrix::use::accumulator")), - WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR))) + TYPE_CONDITIONAL_FACTORY( + checkEnableJointMatrixForType(), + TYPE_FACTORY(STR(MapNames::getDpctNamespace() + + "experimental::matrix::accumulator")), + WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, + TYPESTR))) -TYPE_REWRITE_ENTRY("nvcuda::wmma::fragment", +TYPE_REWRITE_ENTRY( + "nvcuda::wmma::fragment", TYPE_CONDITIONAL_FACTORY( checkEnableJointMatrixForType(), TYPE_CONDITIONAL_FACTORY( CheckTemplateArgCount(5), - TYPE_FACTORY(STR(MapNames::getClNamespace() + - "ext::oneapi::experimental::matrix::joint_matrix"), - STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0), - TEMPLATE_ARG(1),TEMPLATE_ARG(2)), - TYPE_CONDITIONAL_FACTORY( - checkTemplateArgSpelling(0, "nvcuda::wmma::matrix_a"), - TYPE_FACTORY(STR(MapNames::getClNamespace() + - "ext::oneapi::experimental::matrix::joint_matrix"), - STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0), - TEMPLATE_ARG(1),TEMPLATE_ARG(3),TEMPLATE_ARG(5)), - TYPE_FACTORY(STR(MapNames::getClNamespace() + - "ext::oneapi::experimental::matrix::joint_matrix"), - STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0), - TEMPLATE_ARG(2),TEMPLATE_ARG(3),TEMPLATE_ARG(5)))), + TYPE_FACTORY(STR(MapNames::getDpctNamespace() + + "experimental::matrix::joint_matrix"), + TEMPLATE_ARG(0), TEMPLATE_ARG(1), TEMPLATE_ARG(2), + TEMPLATE_ARG(3), TEMPLATE_ARG(4)), + TYPE_FACTORY(STR(MapNames::getDpctNamespace() + + "experimental::matrix::joint_matrix"), + TEMPLATE_ARG(0), TEMPLATE_ARG(1), TEMPLATE_ARG(2), + TEMPLATE_ARG(3), TEMPLATE_ARG(4), TEMPLATE_ARG(5))), WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR))) TYPE_REWRITE_ENTRY( diff --git a/clang/lib/DPCT/APINamesWmma.inc b/clang/lib/DPCT/APINamesWmma.inc index 6008f65ddd47..1c9165fc9c78 100644 --- a/clang/lib/DPCT/APINamesWmma.inc +++ b/clang/lib/DPCT/APINamesWmma.inc @@ -12,7 +12,7 @@ CONDITIONAL_FACTORY_ENTRY( "nvcuda::wmma::fill_fragment", CALL(MapNames::getClNamespace() + "ext::oneapi::experimental::matrix::joint_matrix_fill", - SUBGROUP, ARG(0), ARG(1))), + SUBGROUP, MEMBER_CALL(ARG(0), false, "get"), ARG(1))), ENTRY_UNSUPPORTED("nvcuda::wmma::fill_fragment", Diagnostics::API_NOT_MIGRATED)) @@ -24,7 +24,7 @@ CONDITIONAL_FACTORY_ENTRY( "nvcuda::wmma::load_matrix_sync", CALL(MapNames::getClNamespace() + "ext::oneapi::experimental::matrix::joint_matrix_load", - SUBGROUP, ARG(0), + SUBGROUP, MEMBER_CALL(ARG(0), false, "get"), CALL(TEMPLATED_CALLEE_WITH_ARGS( MapNames::getClNamespace() + "address_space_cast", LITERAL(MapNames::getClNamespace() + @@ -38,7 +38,7 @@ CONDITIONAL_FACTORY_ENTRY( "nvcuda::wmma::load_matrix_sync", CALL(MapNames::getClNamespace() + "ext::oneapi::experimental::matrix::joint_matrix_load", - SUBGROUP, ARG(0), + SUBGROUP, MEMBER_CALL(ARG(0), false, "get"), CALL(TEMPLATED_CALLEE_WITH_ARGS( MapNames::getClNamespace() + "address_space_cast", LITERAL(MapNames::getClNamespace() + @@ -57,7 +57,10 @@ CONDITIONAL_FACTORY_ENTRY( "nvcuda::wmma::mma_sync", CALL(MapNames::getClNamespace() + "ext::oneapi::experimental::matrix::joint_matrix_mad", - SUBGROUP, ARG(0), ARG(1), ARG(2), ARG(3))), + SUBGROUP, MEMBER_CALL(ARG(0), false, "get"), + MEMBER_CALL(ARG(1), false, "get"), + MEMBER_CALL(ARG(2), false, "get"), + MEMBER_CALL(ARG(3), false, "get"))), ENTRY_UNSUPPORTED("nvcuda::wmma::mma_sync", Diagnostics::API_NOT_MIGRATED)) CONDITIONAL_FACTORY_ENTRY( @@ -66,7 +69,7 @@ CONDITIONAL_FACTORY_ENTRY( "nvcuda::wmma::store_matrix_sync", CALL(MapNames::getClNamespace() + "ext::oneapi::experimental::matrix::joint_matrix_store", - SUBGROUP, ARG(1), + SUBGROUP, MEMBER_CALL(ARG(1), false, "get"), CALL(TEMPLATED_CALLEE_WITH_ARGS( MapNames::getClNamespace() + "address_space_cast", LITERAL(MapNames::getClNamespace() + diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 96cd45351b43..6aeba145521f 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2135,6 +2135,74 @@ template inline constexpr RetT extend_vavrg4_sat(AT a, BT b, RetT c) { return detail::extend_vbinary4(a, b, c, detail::average()); } + +namespace experimental { +namespace matrix { +namespace syclex = sycl::ext::oneapi::experimental; +struct row_major + : public std::integral_constant {}; +struct col_major + : public std::integral_constant {}; +struct a : public std::integral_constant {}; +struct b : public std::integral_constant {}; +struct accumulator + : public std::integral_constant {}; + +template struct matrix_size_traits; +template struct matrix_size_traits { + static constexpr int rows = m; + static constexpr int cols = k; +}; + +template struct matrix_size_traits { + static constexpr int rows = k; + static constexpr int cols = n; +}; + +template struct matrix_size_traits { + static constexpr int rows = m; + static constexpr int cols = n; +}; + +// A class that wraps the syclex::matrix::joint_matrix class and provides +// copy constructor and assignment operator. +template > +class joint_matrix { + using joint_matrix_type = syclex::matrix::joint_matrix< + sycl::sub_group, T, use::value, matrix_size_traits::rows, + matrix_size_traits::cols, layout::value>; + +public: + joint_matrix() : matrix() {} + joint_matrix(joint_matrix &other) { + syclex::matrix::joint_matrix_copy(syclex::this_sub_group(), other.get(), + matrix); + } + joint_matrix &operator=(joint_matrix &other) { + if (this != &other) { + syclex::matrix::joint_matrix_copy(syclex::this_sub_group(), other.get(), + matrix); + } + return *this; + } + + joint_matrix_type &get() { return matrix; } + + const joint_matrix_type &get() const { return matrix; } + +private: + joint_matrix_type matrix; +}; +} // namespace matrix +} // namespace experimental + } // namespace dpct #endif // __DPCT_MATH_HPP__ diff --git a/clang/test/dpct/enable-all-experimental-features.cu b/clang/test/dpct/enable-all-experimental-features.cu index 37bf828351e2..272253b6950e 100644 --- a/clang/test/dpct/enable-all-experimental-features.cu +++ b/clang/test/dpct/enable-all-experimental-features.cu @@ -337,17 +337,17 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld, // CHECK: sycl::ext::oneapi::experimental::matrix::layout ly = sycl::ext::oneapi::experimental::matrix::layout::row_major; nvcuda::wmma::layout_t ly = nvcuda::wmma::mem_row_major; // Declare the fragments - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix + // CHECK: dpct::experimental::matrix::joint_matrix nvcuda::wmma::fragment a_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix + // CHECK: dpct::experimental::matrix::joint_matrix nvcuda::wmma::fragment b_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix acc_frag; + // CHECK: dpct::experimental::matrix::joint_matrix acc_frag; nvcuda::wmma::fragment acc_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag; + // CHECK: dpct::experimental::matrix::joint_matrix c_frag; nvcuda::wmma::fragment c_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag, 0.0f); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag.get(), 0.0f); nvcuda::wmma::fill_fragment(acc_frag, 0.0f); // Loop over k @@ -360,13 +360,13 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld, // Bounds checking if (aRow < m_ld && aCol < k_ld && bRow < k_ld && bCol < n_ld) { // Load the inputs - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), a_frag, sycl::address_space_cast(a + aCol + aRow * lda), lda); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), a_frag.get(), sycl::address_space_cast(a + aCol + aRow * lda), lda); nvcuda::wmma::load_matrix_sync(a_frag, a + aCol + aRow * lda, lda); - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), b_frag, sycl::address_space_cast(b + bRow + bCol * ldb), ldb); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), b_frag.get(), sycl::address_space_cast(b + bRow + bCol * ldb), ldb); nvcuda::wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb); // Perform the matrix multiplication - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag, a_frag, b_frag, acc_frag); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag.get(), a_frag.get(), b_frag.get(), acc_frag.get()); nvcuda::wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag); } } @@ -377,12 +377,12 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld, int cRow = warpM * WMMA_M; if (cRow < m_ld && cCol < n_ld) { - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag, sycl::address_space_cast(c + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag.get(), sycl::address_space_cast(c + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major); nvcuda::wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc, nvcuda::wmma::mem_row_major); - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag, sycl::address_space_cast(c + cCol + cRow * ldc), ldc, ly); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag.get(), sycl::address_space_cast(c + cCol + cRow * ldc), ldc, ly); nvcuda::wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc, ly); // Store the output - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::experimental::this_sub_group(), c_frag, sycl::address_space_cast(d + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::col_major); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::experimental::this_sub_group(), c_frag.get(), sycl::address_space_cast(d + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::col_major); nvcuda::wmma::store_matrix_sync(d + cCol + cRow * ldc, c_frag, ldc, nvcuda::wmma::mem_col_major); } } diff --git a/clang/test/dpct/wmma.cu b/clang/test/dpct/wmma.cu index b411b2dc6e88..95e2df180a8c 100644 --- a/clang/test/dpct/wmma.cu +++ b/clang/test/dpct/wmma.cu @@ -74,17 +74,17 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld, // CHECK: sycl::ext::oneapi::experimental::matrix::layout ly = sycl::ext::oneapi::experimental::matrix::layout::row_major; nvcuda::wmma::layout_t ly = nvcuda::wmma::mem_row_major; // Declare the fragments - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix + // CHECK: dpct::experimental::matrix::joint_matrix nvcuda::wmma::fragment a_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix + // CHECK: dpct::experimental::matrix::joint_matrix nvcuda::wmma::fragment b_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix acc_frag; + // CHECK: dpct::experimental::matrix::joint_matrix acc_frag; nvcuda::wmma::fragment acc_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag; + // CHECK: dpct::experimental::matrix::joint_matrix c_frag; nvcuda::wmma::fragment c_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(item_ct1.get_sub_group(), acc_frag, 0.0f); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(item_ct1.get_sub_group(), acc_frag.get(), 0.0f); nvcuda::wmma::fill_fragment(acc_frag, 0.0f); // Loop over k @@ -97,13 +97,13 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld, // Bounds checking if (aRow < m_ld && aCol < k_ld && bRow < k_ld && bCol < n_ld) { // Load the inputs - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast(a + aCol + aRow * lda), lda); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag.get(), sycl::address_space_cast(a + aCol + aRow * lda), lda); nvcuda::wmma::load_matrix_sync(a_frag, a + aCol + aRow * lda, lda); - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), b_frag, sycl::address_space_cast(b + bRow + bCol * ldb), ldb); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), b_frag.get(), sycl::address_space_cast(b + bRow + bCol * ldb), ldb); nvcuda::wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb); // Perform the matrix multiplication - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), acc_frag, a_frag, b_frag, acc_frag); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), acc_frag.get(), a_frag.get(), b_frag.get(), acc_frag.get()); nvcuda::wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag); } } @@ -114,13 +114,13 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld, int cRow = warpM * WMMA_M; if (cRow < m_ld && cCol < n_ld) { - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), c_frag, sycl::address_space_cast(c + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), c_frag.get(), sycl::address_space_cast(c + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major); nvcuda::wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc, nvcuda::wmma::mem_row_major); - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), c_frag, sycl::address_space_cast(c + cCol + cRow * ldc), ldc, ly); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), c_frag.get(), sycl::address_space_cast(c + cCol + cRow * ldc), ldc, ly); nvcuda::wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc, ly); // Store the output - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(item_ct1.get_sub_group(), c_frag, sycl::address_space_cast(d + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::col_major); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(item_ct1.get_sub_group(), c_frag.get(), sycl::address_space_cast(d + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::col_major); nvcuda::wmma::store_matrix_sync(d + cCol + cRow * ldc, c_frag, ldc, nvcuda::wmma::mem_col_major); } diff --git a/clang/test/dpct/wmma_using_nvcuda.cu b/clang/test/dpct/wmma_using_nvcuda.cu index 04d6f6f58736..dbc2fdf729f8 100644 --- a/clang/test/dpct/wmma_using_nvcuda.cu +++ b/clang/test/dpct/wmma_using_nvcuda.cu @@ -76,17 +76,17 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld, int warpN = (blockIdx.y * blockDim.y + threadIdx.y); // Declare the fragments - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix + // CHECK: dpct::experimental::matrix::joint_matrix wmma::fragment a_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix + // CHECK: dpct::experimental::matrix::joint_matrix wmma::fragment b_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix acc_frag; + // CHECK: dpct::experimental::matrix::joint_matrix acc_frag; wmma::fragment acc_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag; + // CHECK: dpct::experimental::matrix::joint_matrix c_frag; wmma::fragment c_frag; - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(item_ct1.get_sub_group(), acc_frag, 0.0f); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(item_ct1.get_sub_group(), acc_frag.get(), 0.0f); wmma::fill_fragment(acc_frag, 0.0f); // Loop over k @@ -99,13 +99,13 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld, // Bounds checking if (aRow < m_ld && aCol < k_ld && bRow < k_ld && bCol < n_ld) { // Load the inputs - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast(a + aCol + aRow * lda), lda); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag.get(), sycl::address_space_cast(a + aCol + aRow * lda), lda); wmma::load_matrix_sync(a_frag, a + aCol + aRow * lda, lda); - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), b_frag, sycl::address_space_cast(b + bRow + bCol * ldb), ldb); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), b_frag.get(), sycl::address_space_cast(b + bRow + bCol * ldb), ldb); wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb); // Perform the matrix multiplication - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), acc_frag, a_frag, b_frag, acc_frag); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), acc_frag.get(), a_frag.get(), b_frag.get(), acc_frag.get()); wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag); } } @@ -116,12 +116,12 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld, int cRow = warpM * WMMA_M; if (cRow < m_ld && cCol < n_ld) { - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), c_frag, sycl::address_space_cast(c + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), c_frag.get(), sycl::address_space_cast(c + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major); wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc, wmma::mem_row_major); // Store the output - // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(item_ct1.get_sub_group(), c_frag, sycl::address_space_cast(d + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major); + // CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(item_ct1.get_sub_group(), c_frag.get(), sycl::address_space_cast(d + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major); wmma::store_matrix_sync(d + cCol + cRow * ldc, c_frag, ldc, wmma::mem_row_major); }