Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Fix issues during migration of wmma functions and types #2035

Merged
merged 14 commits into from
Jun 24, 2024
80 changes: 40 additions & 40 deletions clang/lib/DPCT/APINamesTemplateType.inc
Original file line number Diff line number Diff line change
Expand Up @@ -274,59 +274,59 @@ TYPE_REWRITE_ENTRY("thrust::random::uniform_int_distribution",
TYPE_FACTORY(STR("oneapi::dpl::uniform_int_distribution"), TEMPLATE_ARG(0)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::matrix_a",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::use::a")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::a")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::matrix_b",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::use::b")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::b")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::row_major",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::layout::row_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::row_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::col_major",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::layout::col_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::col_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::accumulator",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::use::accumulator")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::accumulator")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::fragment",
TYPE_REWRITE_ENTRY(
"nvcuda::wmma::fragment",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(5),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix"),
STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0),
TEMPLATE_ARG(1),TEMPLATE_ARG(2)),
TYPE_CONDITIONAL_FACTORY(
checkTemplateArgSpelling(0, "nvcuda::wmma::matrix_a"),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix"),
STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0),
TEMPLATE_ARG(1),TEMPLATE_ARG(3),TEMPLATE_ARG(5)),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix"),
STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0),
TEMPLATE_ARG(2),TEMPLATE_ARG(3),TEMPLATE_ARG(5)))),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix"),
TEMPLATE_ARG(0), TEMPLATE_ARG(1), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(4)),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix"),
TEMPLATE_ARG(0), TEMPLATE_ARG(1), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(4), TEMPLATE_ARG(5))),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))

TYPE_REWRITE_ENTRY(
Expand Down
13 changes: 8 additions & 5 deletions clang/lib/DPCT/APINamesWmma.inc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ CONDITIONAL_FACTORY_ENTRY(
"nvcuda::wmma::fill_fragment",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_fill",
SUBGROUP, ARG(0), ARG(1))),
SUBGROUP, MEMBER_CALL(ARG(0), false, "get"), ARG(1))),
ENTRY_UNSUPPORTED("nvcuda::wmma::fill_fragment",
Diagnostics::API_NOT_MIGRATED))

Expand All @@ -24,7 +24,7 @@ CONDITIONAL_FACTORY_ENTRY(
"nvcuda::wmma::load_matrix_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_load",
SUBGROUP, ARG(0),
SUBGROUP, MEMBER_CALL(ARG(0), false, "get"),
CALL(TEMPLATED_CALLEE_WITH_ARGS(
MapNames::getClNamespace() + "address_space_cast",
LITERAL(MapNames::getClNamespace() +
Expand All @@ -38,7 +38,7 @@ CONDITIONAL_FACTORY_ENTRY(
"nvcuda::wmma::load_matrix_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_load",
SUBGROUP, ARG(0),
SUBGROUP, MEMBER_CALL(ARG(0), false, "get"),
CALL(TEMPLATED_CALLEE_WITH_ARGS(
MapNames::getClNamespace() + "address_space_cast",
LITERAL(MapNames::getClNamespace() +
Expand All @@ -57,7 +57,10 @@ CONDITIONAL_FACTORY_ENTRY(
"nvcuda::wmma::mma_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_mad",
SUBGROUP, ARG(0), ARG(1), ARG(2), ARG(3))),
SUBGROUP, MEMBER_CALL(ARG(0), false, "get"),
MEMBER_CALL(ARG(1), false, "get"),
MEMBER_CALL(ARG(2), false, "get"),
MEMBER_CALL(ARG(3), false, "get"))),
ENTRY_UNSUPPORTED("nvcuda::wmma::mma_sync", Diagnostics::API_NOT_MIGRATED))

CONDITIONAL_FACTORY_ENTRY(
Expand All @@ -66,7 +69,7 @@ CONDITIONAL_FACTORY_ENTRY(
"nvcuda::wmma::store_matrix_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_store",
SUBGROUP, ARG(1),
SUBGROUP, MEMBER_CALL(ARG(1), false, "get"),
CALL(TEMPLATED_CALLEE_WITH_ARGS(
MapNames::getClNamespace() + "address_space_cast",
LITERAL(MapNames::getClNamespace() +
Expand Down
68 changes: 68 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2135,6 +2135,74 @@ template <typename RetT, typename AT, typename BT>
inline constexpr RetT extend_vavrg4_sat(AT a, BT b, RetT c) {
return detail::extend_vbinary4<RetT, true, false>(a, b, c, detail::average());
}

namespace experimental {
namespace matrix {
namespace syclex = sycl::ext::oneapi::experimental;
/// Layout tag types. Each one carries a syclex::matrix::layout enumerator as a
/// type (through std::integral_constant), so it can be passed as a type
/// template argument and read back via `::value`.
struct row_major : std::integral_constant<syclex::matrix::layout,
                                          syclex::matrix::layout::row_major> {
};
struct col_major : std::integral_constant<syclex::matrix::layout,
                                          syclex::matrix::layout::col_major> {
};

/// Matrix-use tag types. Each one carries a syclex::matrix::use enumerator as
/// a type, in the same way as the layout tags above.
struct a
    : std::integral_constant<syclex::matrix::use, syclex::matrix::use::a> {};
struct b
    : std::integral_constant<syclex::matrix::use, syclex::matrix::use::b> {};
struct accumulator
    : std::integral_constant<syclex::matrix::use,
                             syclex::matrix::use::accumulator> {};

/// Maps a matrix-use tag together with an (m, n, k) shape triple to the row
/// and column counts of that operand. The primary template is only declared;
/// the specializations below for `a`, `b` and `accumulator` provide the
/// definitions.
template <typename use, int m, int n, int k> struct matrix_size_traits;

/// `a` operand: m rows, k columns.
template <int m, int n, int k> struct matrix_size_traits<a, m, n, k> {
  static constexpr int rows{m};
  static constexpr int cols{k};
};

/// `b` operand: k rows, n columns.
template <int m, int n, int k> struct matrix_size_traits<b, m, n, k> {
  static constexpr int rows{k};
  static constexpr int cols{n};
};

/// `accumulator` operand: m rows, n columns.
template <int m, int n, int k>
struct matrix_size_traits<accumulator, m, n, k> {
  static constexpr int rows{m};
  static constexpr int cols{n};
};

/// Wrapper around syclex::matrix::joint_matrix that adds a copy constructor
/// and a copy-assignment operator, both implemented with
/// syclex::matrix::joint_matrix_copy on the current sub-group.
///
/// \tparam use    Tag type whose `::value` is a syclex::matrix::use.
/// \tparam m      Shape parameter forwarded to matrix_size_traits.
/// \tparam n      Shape parameter forwarded to matrix_size_traits.
/// \tparam k      Shape parameter forwarded to matrix_size_traits.
/// \tparam T      Element type of the wrapped matrix.
/// \tparam layout Tag type whose `::value` is a syclex::matrix::layout;
///                defaults to a tag holding layout::dynamic.
template <typename use, int m, int n, int k, typename T,
          typename layout = std::integral_constant<
              syclex::matrix::layout, syclex::matrix::layout::dynamic>>
class joint_matrix {
  // Row/column counts of this operand, derived from the use tag and (m, n, k).
  using extent = matrix_size_traits<use, m, n, k>;
  // The underlying SYCL joint matrix type, always scoped to sycl::sub_group.
  using wrapped_type =
      syclex::matrix::joint_matrix<sycl::sub_group, T, use::value,
                                   extent::rows, extent::cols, layout::value>;

public:
  /// Value-initializes the wrapped matrix.
  joint_matrix() : wrapped() {}

  /// Copies the contents of \p other into a newly constructed wrapper.
  // NOTE(review): the parameter is a non-const reference, presumably because
  // joint_matrix_copy takes its source by non-const reference -- confirm
  // before tightening this to `const joint_matrix &`.
  joint_matrix(joint_matrix &other) { copy_from(other); }

  /// Copies the contents of \p other into this wrapper; self-assignment is a
  /// no-op. Returns a reference to this object.
  joint_matrix &operator=(joint_matrix &other) {
    if (this == &other)
      return *this;
    copy_from(other);
    return *this;
  }

  /// Access to the wrapped SYCL joint matrix.
  wrapped_type &get() { return wrapped; }

  /// Read-only access to the wrapped SYCL joint matrix.
  const wrapped_type &get() const { return wrapped; }

private:
  // Element-wise copy of `source` into this wrapper via joint_matrix_copy.
  void copy_from(joint_matrix &source) {
    syclex::matrix::joint_matrix_copy(syclex::this_sub_group(), source.get(),
                                      wrapped);
  }

  wrapped_type wrapped;
};
} // namespace matrix
} // namespace experimental

} // namespace dpct

#endif // __DPCT_MATH_HPP__
22 changes: 11 additions & 11 deletions clang/test/dpct/enable-all-experimental-features.cu
Original file line number Diff line number Diff line change
Expand Up @@ -337,17 +337,17 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld,
// CHECK: sycl::ext::oneapi::experimental::matrix::layout ly = sycl::ext::oneapi::experimental::matrix::layout::row_major;
nvcuda::wmma::layout_t ly = nvcuda::wmma::mem_row_major;
// Declare the fragments
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix<sycl::sub_group, sycl::half, sycl::ext::oneapi::experimental::matrix::use::a, WMMA_M, WMMA_K, sycl::ext::oneapi::experimental::matrix::layout::row_major>
// CHECK: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::a, WMMA_M, WMMA_N, WMMA_K, sycl::half, dpct::experimental::matrix::row_major>
nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, nvcuda::wmma::row_major>
a_frag;
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix<sycl::sub_group, sycl::half, sycl::ext::oneapi::experimental::matrix::use::b, WMMA_N, WMMA_K, sycl::ext::oneapi::experimental::matrix::layout::col_major>
// CHECK: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::b, WMMA_M, WMMA_N, WMMA_K, sycl::half, dpct::experimental::matrix::col_major>
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, nvcuda::wmma::col_major>
b_frag;
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix<sycl::sub_group, float, sycl::ext::oneapi::experimental::matrix::use::accumulator, WMMA_M, WMMA_N> acc_frag;
// CHECK: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix<sycl::sub_group, float, sycl::ext::oneapi::experimental::matrix::use::accumulator, WMMA_M, WMMA_N> c_frag;
// CHECK: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag, 0.0f);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag.get(), 0.0f);
nvcuda::wmma::fill_fragment(acc_frag, 0.0f);

// Loop over k
Expand All @@ -360,13 +360,13 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld,
// Bounds checking
if (aRow < m_ld && aCol < k_ld && bRow < k_ld && bCol < n_ld) {
// Load the inputs
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), a_frag, sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const sycl::half>(a + aCol + aRow * lda), lda);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), a_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const sycl::half>(a + aCol + aRow * lda), lda);
nvcuda::wmma::load_matrix_sync(a_frag, a + aCol + aRow * lda, lda);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), b_frag, sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const sycl::half>(b + bRow + bCol * ldb), ldb);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), b_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const sycl::half>(b + bRow + bCol * ldb), ldb);
nvcuda::wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);

// Perform the matrix multiplication
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag, a_frag, b_frag, acc_frag);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sycl::ext::oneapi::experimental::this_sub_group(), acc_frag.get(), a_frag.get(), b_frag.get(), acc_frag.get());
nvcuda::wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
}
}
Expand All @@ -377,12 +377,12 @@ __global__ void simple_wmma_gemm(half *a, half *b, float *c, float *d, int m_ld,
int cRow = warpM * WMMA_M;

if (cRow < m_ld && cCol < n_ld) {
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag, sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const float>(c + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const float>(c + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major);
nvcuda::wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc, nvcuda::wmma::mem_row_major);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag, sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const float>(c + cCol + cRow * ldc), ldc, ly);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::experimental::this_sub_group(), c_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const float>(c + cCol + cRow * ldc), ldc, ly);
nvcuda::wmma::load_matrix_sync(c_frag, c + cCol + cRow * ldc, ldc, ly);
// Store the output
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::experimental::this_sub_group(), c_frag, sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, float>(d + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::col_major);
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::experimental::this_sub_group(), c_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, float>(d + cCol + cRow * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::col_major);
nvcuda::wmma::store_matrix_sync(d + cCol + cRow * ldc, c_frag, ldc, nvcuda::wmma::mem_col_major);
}
}
Expand Down
Loading
Loading