Skip to content

Commit

Permalink
[SYCLomatic] Fix issues during migration of wmma functions and types (#…
Browse files Browse the repository at this point in the history
…2035)


Signed-off-by: intwanghao <[email protected]>
  • Loading branch information
intwanghao authored Jun 24, 2024
1 parent 9c7be20 commit d0a7e53
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 77 deletions.
80 changes: 40 additions & 40 deletions clang/lib/DPCT/APINamesTemplateType.inc
Original file line number Diff line number Diff line change
Expand Up @@ -274,59 +274,59 @@ TYPE_REWRITE_ENTRY("thrust::random::uniform_int_distribution",
TYPE_FACTORY(STR("oneapi::dpl::uniform_int_distribution"), TEMPLATE_ARG(0)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::matrix_a",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::use::a")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::a")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::matrix_b",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::use::b")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::b")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::row_major",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::layout::row_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::row_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::col_major",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::layout::col_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::col_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::accumulator",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::use::accumulator")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::accumulator")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::fragment",
TYPE_REWRITE_ENTRY(
"nvcuda::wmma::fragment",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(5),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix"),
STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0),
TEMPLATE_ARG(1),TEMPLATE_ARG(2)),
TYPE_CONDITIONAL_FACTORY(
checkTemplateArgSpelling(0, "nvcuda::wmma::matrix_a"),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix"),
STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0),
TEMPLATE_ARG(1),TEMPLATE_ARG(3),TEMPLATE_ARG(5)),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix"),
STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0),
TEMPLATE_ARG(2),TEMPLATE_ARG(3),TEMPLATE_ARG(5)))),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix"),
TEMPLATE_ARG(0), TEMPLATE_ARG(1), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(4)),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix"),
TEMPLATE_ARG(0), TEMPLATE_ARG(1), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(4), TEMPLATE_ARG(5))),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))

TYPE_REWRITE_ENTRY(
Expand Down
13 changes: 8 additions & 5 deletions clang/lib/DPCT/APINamesWmma.inc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ CONDITIONAL_FACTORY_ENTRY(
"nvcuda::wmma::fill_fragment",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_fill",
SUBGROUP, ARG(0), ARG(1))),
SUBGROUP, MEMBER_CALL(ARG(0), false, "get"), ARG(1))),
ENTRY_UNSUPPORTED("nvcuda::wmma::fill_fragment",
Diagnostics::API_NOT_MIGRATED))

Expand All @@ -24,7 +24,7 @@ CONDITIONAL_FACTORY_ENTRY(
"nvcuda::wmma::load_matrix_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_load",
SUBGROUP, ARG(0),
SUBGROUP, MEMBER_CALL(ARG(0), false, "get"),
CALL(TEMPLATED_CALLEE_WITH_ARGS(
MapNames::getClNamespace() + "address_space_cast",
LITERAL(MapNames::getClNamespace() +
Expand All @@ -38,7 +38,7 @@ CONDITIONAL_FACTORY_ENTRY(
"nvcuda::wmma::load_matrix_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_load",
SUBGROUP, ARG(0),
SUBGROUP, MEMBER_CALL(ARG(0), false, "get"),
CALL(TEMPLATED_CALLEE_WITH_ARGS(
MapNames::getClNamespace() + "address_space_cast",
LITERAL(MapNames::getClNamespace() +
Expand All @@ -57,7 +57,10 @@ CONDITIONAL_FACTORY_ENTRY(
"nvcuda::wmma::mma_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_mad",
SUBGROUP, ARG(0), ARG(1), ARG(2), ARG(3))),
SUBGROUP, MEMBER_CALL(ARG(0), false, "get"),
MEMBER_CALL(ARG(1), false, "get"),
MEMBER_CALL(ARG(2), false, "get"),
MEMBER_CALL(ARG(3), false, "get"))),
ENTRY_UNSUPPORTED("nvcuda::wmma::mma_sync", Diagnostics::API_NOT_MIGRATED))

CONDITIONAL_FACTORY_ENTRY(
Expand All @@ -66,7 +69,7 @@ CONDITIONAL_FACTORY_ENTRY(
"nvcuda::wmma::store_matrix_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_store",
SUBGROUP, ARG(1),
SUBGROUP, MEMBER_CALL(ARG(1), false, "get"),
CALL(TEMPLATED_CALLEE_WITH_ARGS(
MapNames::getClNamespace() + "address_space_cast",
LITERAL(MapNames::getClNamespace() +
Expand Down
68 changes: 68 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2135,6 +2135,74 @@ template <typename RetT, typename AT, typename BT>
inline constexpr RetT extend_vavrg4_sat(AT a, BT b, RetT c) {
return detail::extend_vbinary4<RetT, true, false>(a, b, c, detail::average());
}

namespace experimental {
namespace matrix {
namespace syclex = sycl::ext::oneapi::experimental;
struct row_major
: public std::integral_constant<syclex::matrix::layout,
syclex::matrix::layout::row_major> {};
struct col_major
: public std::integral_constant<syclex::matrix::layout,
syclex::matrix::layout::col_major> {};
struct a : public std::integral_constant<syclex::matrix::use,
syclex::matrix::use::a> {};
struct b : public std::integral_constant<syclex::matrix::use,
syclex::matrix::use::b> {};
struct accumulator
: public std::integral_constant<syclex::matrix::use,
syclex::matrix::use::accumulator> {};

template <class use, int m, int n, int k> struct matrix_size_traits;
template <int m, int n, int k> struct matrix_size_traits<a, m, n, k> {
static constexpr int rows = m;
static constexpr int cols = k;
};

template <int m, int n, int k> struct matrix_size_traits<b, m, n, k> {
static constexpr int rows = k;
static constexpr int cols = n;
};

template <int m, int n, int k> struct matrix_size_traits<accumulator, m, n, k> {
static constexpr int rows = m;
static constexpr int cols = n;
};

// A class that wraps the syclex::matrix::joint_matrix class and provides
// copy constructor and assignment operator.
template <typename use, int m, int n, int k, typename T,
typename layout = std::integral_constant<
syclex::matrix::layout, syclex::matrix::layout::dynamic>>
class joint_matrix {
using joint_matrix_type = syclex::matrix::joint_matrix<
sycl::sub_group, T, use::value, matrix_size_traits<use, m, n, k>::rows,
matrix_size_traits<use, m, n, k>::cols, layout::value>;

public:
joint_matrix() : matrix() {}
joint_matrix(joint_matrix &other) {
syclex::matrix::joint_matrix_copy(syclex::this_sub_group(), other.get(),
matrix);
}
joint_matrix &operator=(joint_matrix &other) {
if (this != &other) {
syclex::matrix::joint_matrix_copy(syclex::this_sub_group(), other.get(),
matrix);
}
return *this;
}

joint_matrix_type &get() { return matrix; }

const joint_matrix_type &get() const { return matrix; }

private:
joint_matrix_type matrix;
};
} // namespace matrix
} // namespace experimental

} // namespace dpct

#endif // __DPCT_MATH_HPP__
22 changes: 11 additions & 11 deletions clang/test/dpct/enable-all-experimental-features.cu
Original file line number Diff line number Diff line change
Expand Up @@ -337,17 +337,17 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld,
// CHECK: sycl::ext::oneapi::experimental::matrix::layout ly = sycl::ext::oneapi::experimental::matrix::layout::row_major;
nvcuda::wmma::layout_t ly = nvcuda::wmma::mem_row_major;
// Declare the fragments
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix<sycl::sub_group, sycl::half, sycl::ext::oneapi::experimental::matrix::use::a, WMMA_M, WMMA_K, sycl::ext::oneapi::experimental::matrix::layout::row_major>
// CHECK: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::a, WMMA_M, WMMA_N, WMMA_K, sycl::half, dpct::experimental::matrix::row_major>
nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, nvcuda::wmma::row_major>
a_frag;
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix<sycl::sub_group, sycl::half, sycl::ext::oneapi::experimental::matrix::use::b, WMMA_N, WMMA_K, sycl::ext::oneapi::experimental::matrix::layout::col_major>
// CHECK: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::b, WMMA_M, WMMA_N, WMMA_K, sycl::half, dpct::experimental::matrix::col_major>
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, nvcuda::wmma::col_major>
b_frag;
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix<sycl::sub_group, float, sycl::ext::oneapi::experimental::matrix::use::accumulator, WMMA_M, WMMA_N> acc_frag;
// CHECK: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix<sycl::sub_group, float, sycl::ext::oneapi::experimental::matrix::use::accumulator, WMMA_M, WMMA_N> c_frag;
// CHECK: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag, 0.0f);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag.get(), 0.0f);
nvcuda::wmma::fill_fragment(acc_frag, 0.0f);

// Loop over k
Expand All @@ -360,13 +360,13 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld,
// Bounds checking
if (aRow < m_ld && aCol < k_ld && bRow < k_ld && bCol < n_ld) {
// Load the inputs
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), a_frag, sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const sycl::half>(a + aCol + aRow * lda), lda);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), a_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const sycl::half>(a + aCol + aRow * lda), lda);
nvcuda::wmma::load_matrix_sync(a_frag, a + aCol + aRow * lda, lda);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), b_frag, sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const sycl::half>(b + bRow + bCol * ldb), ldb);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), b_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const sycl::half>(b + bRow + bCol * ldb), ldb);
nvcuda::wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);

// Perform the matrix multiplication
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag, a_frag, b_frag, acc_frag);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag.get(), a_frag.get(), b_frag.get(), acc_frag.get());
nvcuda::wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
}
}
Expand All @@ -377,12 +377,12 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld,
int cRow = warpM * WMMA_M;

if (cRow < m_ld && cCol < n_ld) {
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag, sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const float>(c + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const float>(c + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major);
nvcuda::wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc, nvcuda::wmma::mem_row_major);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag, sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const float>(c + cCol + cRow * ldc), ldc, ly);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const float>(c + cCol + cRow * ldc), ldc, ly);
nvcuda::wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc, ly);
// Store the output
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::experimental::this_sub_group(), c_frag, sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, float>(d + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::col_major);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::experimental::this_sub_group(), c_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, float>(d + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::col_major);
nvcuda::wmma::store_matrix_sync(d + cCol + cRow * ldc, c_frag, ldc, nvcuda::wmma::mem_col_major);
}
}
Expand Down
Loading

0 comments on commit d0a7e53

Please sign in to comment.