Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Fix issues during migration of wmma functions and types #2035

Merged
merged 14 commits into from
Jun 24, 2024
94 changes: 47 additions & 47 deletions clang/lib/DPCT/APINamesTemplateType.inc
Original file line number Diff line number Diff line change
Expand Up @@ -274,68 +274,68 @@ TYPE_REWRITE_ENTRY("thrust::random::uniform_int_distribution",
TYPE_FACTORY(STR("oneapi::dpl::uniform_int_distribution"), TEMPLATE_ARG(0)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::matrix_a",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::use::a")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::a")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::matrix_b",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::use::b")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::b")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::row_major",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::layout::row_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::row_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::col_major",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::layout::col_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::col_major")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::accumulator",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::use::accumulator")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::accumulator")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

TYPE_REWRITE_ENTRY("nvcuda::wmma::fragment",
TYPE_REWRITE_ENTRY(
"nvcuda::wmma::fragment",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(5),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix"),
STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0),
TEMPLATE_ARG(1),TEMPLATE_ARG(2)),
TYPE_CONDITIONAL_FACTORY(
checkTemplateArgSpelling(0, "nvcuda::wmma::matrix_a"),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix"),
STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0),
TEMPLATE_ARG(1),TEMPLATE_ARG(3),TEMPLATE_ARG(5)),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix"),
STR("sycl::sub_group"),TEMPLATE_ARG(4),TEMPLATE_ARG(0),
TEMPLATE_ARG(2),TEMPLATE_ARG(3),TEMPLATE_ARG(5)))),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix"),
TEMPLATE_ARG(0), TEMPLATE_ARG(1), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(4)),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix"),
TEMPLATE_ARG(0), TEMPLATE_ARG(1), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(4), TEMPLATE_ARG(5))),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))

TYPE_REWRITE_ENTRY(
"nvcuda::wmma::layout_t",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::layout")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE, TYPESTR)))
TYPE_REWRITE_ENTRY("nvcuda::wmma::layout_t",
TYPE_CONDITIONAL_FACTORY(
checkEnableJointMatrixForType(),
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"experimental::matrix::layout_t")),
WARNING_FACTORY(Diagnostics::KNOWN_UNSUPPORTED_TYPE,
TYPESTR)))

// clang-format on

Expand Down
72 changes: 20 additions & 52 deletions clang/lib/DPCT/APINamesWmma.inc
Original file line number Diff line number Diff line change
Expand Up @@ -8,73 +8,41 @@

CONDITIONAL_FACTORY_ENTRY(
checkEnableJointMatrix(),
CALL_FACTORY_ENTRY(
"nvcuda::wmma::fill_fragment",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_fill",
SUBGROUP, ARG(0), ARG(1))),
CALL_FACTORY_ENTRY("nvcuda::wmma::fill_fragment",
CALL(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix_fill",
ARG(0), ARG(1))),
ENTRY_UNSUPPORTED("nvcuda::wmma::fill_fragment",
Diagnostics::API_NOT_MIGRATED))

CONDITIONAL_FACTORY_ENTRY(
checkEnableJointMatrix(),
CONDITIONAL_FACTORY_ENTRY(
CheckArgCount(3),
CALL_FACTORY_ENTRY(
"nvcuda::wmma::load_matrix_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_load",
SUBGROUP, ARG(0),
CALL(TEMPLATED_CALLEE_WITH_ARGS(
MapNames::getClNamespace() + "address_space_cast",
LITERAL(MapNames::getClNamespace() +
"access::address_space::generic_space"),
LITERAL(MapNames::getClNamespace() +
"access::decorated::no"),
getDerefedType(1)),
ARG(1)),
ARG(2))),
CALL_FACTORY_ENTRY(
"nvcuda::wmma::load_matrix_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_load",
SUBGROUP, ARG(0),
CALL(TEMPLATED_CALLEE_WITH_ARGS(
MapNames::getClNamespace() + "address_space_cast",
LITERAL(MapNames::getClNamespace() +
"access::address_space::generic_space"),
LITERAL(MapNames::getClNamespace() +
"access::decorated::no"),
getDerefedType(1)),
ARG(1)),
ARG(2), ARG(3)))),
CALL_FACTORY_ENTRY("nvcuda::wmma::load_matrix_sync",
CALL(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix_load",
ARG(0), ARG(1), ARG(2))),
CALL_FACTORY_ENTRY("nvcuda::wmma::load_matrix_sync",
CALL(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix_load",
ARG(0), ARG(1), ARG(2), ARG(3)))),
ENTRY_UNSUPPORTED("nvcuda::wmma::load_matrix_sync",
Diagnostics::API_NOT_MIGRATED))

CONDITIONAL_FACTORY_ENTRY(
checkEnableJointMatrix(),
CALL_FACTORY_ENTRY(
"nvcuda::wmma::mma_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_mad",
SUBGROUP, ARG(0), ARG(1), ARG(2), ARG(3))),
CALL_FACTORY_ENTRY("nvcuda::wmma::mma_sync",
CALL(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix_mad",
ARG(0), ARG(1), ARG(2), ARG(3))),
ENTRY_UNSUPPORTED("nvcuda::wmma::mma_sync", Diagnostics::API_NOT_MIGRATED))

CONDITIONAL_FACTORY_ENTRY(
checkEnableJointMatrix(),
CALL_FACTORY_ENTRY(
"nvcuda::wmma::store_matrix_sync",
CALL(MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::joint_matrix_store",
SUBGROUP, ARG(1),
CALL(TEMPLATED_CALLEE_WITH_ARGS(
MapNames::getClNamespace() + "address_space_cast",
LITERAL(MapNames::getClNamespace() +
"access::address_space::generic_space"),
LITERAL(MapNames::getClNamespace() +
"access::decorated::no"),
getDerefedType(0)),
ARG(0)),
ARG(2), ARG(3))),
CALL_FACTORY_ENTRY("nvcuda::wmma::store_matrix_sync",
CALL(MapNames::getDpctNamespace() +
"experimental::matrix::joint_matrix_store",
ARG(0), ARG(1), ARG(2), ARG(3))),
ENTRY_UNSUPPORTED("nvcuda::wmma::store_matrix_sync",
Diagnostics::API_NOT_MIGRATED))
2 changes: 2 additions & 0 deletions clang/lib/DPCT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ set(RUNTIME_HEADERS
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/util.hpp
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/blas_utils.hpp
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dnnl_utils.hpp
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/mma_utils.hpp
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dpct.hpp
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dpl_utils.hpp
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/kernel.hpp
Expand Down Expand Up @@ -45,6 +46,7 @@ set(PROCESS_FILES_OUTPUT
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/bindless_images.hpp.inc
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/blas_utils.hpp.inc
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/dnnl_utils.hpp.inc
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/mma_utils.hpp.inc
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/device.hpp.inc
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/dpct.hpp.inc
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/dpl_utils.hpp.inc
Expand Down
8 changes: 4 additions & 4 deletions clang/lib/DPCT/CallExprRewriterWmma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ void CallExprRewriterFactoryBase::initRewriterMapWmma() {
EnumConstantRule::EnumNamesMap.insert(
{"nvcuda::wmma::mem_row_major",
std::make_shared<EnumNameRule>(
MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::layout::row_major")});
MapNames::getDpctNamespace() +
"experimental::matrix::layout_t::m_row_major")});
EnumConstantRule::EnumNamesMap.insert(
{"nvcuda::wmma::mem_col_major",
std::make_shared<EnumNameRule>(
MapNames::getClNamespace() +
"ext::oneapi::experimental::matrix::layout::col_major")});
MapNames::getDpctNamespace() +
"experimental::matrix::layout_t::m_col_major")});
}
RewriterMap->merge(
std::unordered_map<std::string,
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/DPCT/GenHelperFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ const std::string BlasUtilsAllContentStr =
const std::string DnnlUtilsAllContentStr =
#include "clang/DPCT/dnnl_utils.hpp.inc"
;
const std::string MmaUtilsAllContentStr =
#include "clang/DPCT/mma_utils.hpp.inc"
;
const std::string DeviceAllContentStr =
#include "clang/DPCT/device.hpp.inc"
;
Expand Down Expand Up @@ -154,6 +157,7 @@ void genHelperFunction(const clang::tooling::UnifiedPath &OutRoot) {
GENERATE_ALL_FILE_CONTENT(Dpct, ".", dpct.hpp)
GENERATE_ALL_FILE_CONTENT(DplUtils, ".", dpl_utils.hpp)
GENERATE_ALL_FILE_CONTENT(DnnlUtils, ".", dnnl_utils.hpp)
GENERATE_ALL_FILE_CONTENT(MmaUtils, ".", mma_utils.hpp)
GENERATE_ALL_FILE_CONTENT(Image, ".", image.hpp)
GENERATE_ALL_FILE_CONTENT(Kernel, ".", kernel.hpp)
GENERATE_ALL_FILE_CONTENT(Math, ".", math.hpp)
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/HeaderTypes.inc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ DPCT_HEADER(RNG_Utils, "<dpct/rng_utils.hpp>")
DPCT_HEADER(CCL_Utils, "<dpct/ccl_utils.hpp>")
DPCT_HEADER(BLAS_Utils, "<dpct/blas_utils.hpp>")
DPCT_HEADER(DNNL_Utils, "<dpct/dnnl_utils.hpp>")
DPCT_HEADER(MMA_Utils, "<dpct/mma_utils.hpp>")
DPCT_HEADER(LAPACK_Utils, "<dpct/lapack_utils.hpp>")
DPCT_HEADER(GROUP_Utils, "<dpct/group_utils.hpp>")
DPCT_HEADER(COMMON_Utils, "<dpct/lib_common_utils.hpp>")
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/DPCT/InclusionHeaders.inc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ REGIST_INCLUSION("nccl.h", FullMatch, NCCL, Replace, false,
REGIST_INCLUSION("cudnn.h", FullMatch, DNN, Replace, false,
HeaderType::HT_DPCT_DNNL_Utils)

REGIST_INCLUSION("mma.h", FullMatch, DNN, Replace, false,
intwanghao marked this conversation as resolved.
Show resolved Hide resolved
HeaderType::HT_DPCT_MMA_Utils)

REGIST_INCLUSION("cuda/atomic", FullMatch, Libcu, Replace, false,
HeaderType::HT_DPCT_Atomic)
REGIST_INCLUSION("cuda/std/atomic", FullMatch, Libcu, Replace, false,
Expand Down
1 change: 1 addition & 0 deletions clang/runtime/dpct-rt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ set(dpct_rt_files
include/dpct/util.hpp
include/dpct/blas_utils.hpp
include/dpct/dnnl_utils.hpp
include/dpct/mma_utils.hpp
include/dpct/dpl_utils.hpp
include/dpct/rng_utils.hpp
include/dpct/lib_common_utils.hpp
Expand Down
Loading
Loading