Skip to content

Commit

Permalink
[SYCLomatic] Support the migration of matmul & transform in cublasLt (#…
Browse files Browse the repository at this point in the history
…1993)

Introduce a new helper function file, blas_gemm_utils.hpp, which is based on the oneDNN API.

Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 authored Jun 26, 2024
1 parent 221d8ef commit dcd88af
Show file tree
Hide file tree
Showing 16 changed files with 1,436 additions and 67 deletions.
61 changes: 61 additions & 0 deletions clang/lib/DPCT/APINamesCUBLAS.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1634,3 +1634,64 @@ GEMM_BATCH(cublasDgemmStridedBatched, "double", false)
GEMM_BATCH(cublasCgemmStridedBatched, "std::complex<float>", true)
GEMM_BATCH(cublasZgemmStridedBatched, "std::complex<double>", true)
#undef GEMM_BATCH

// Migration rules for the cublasLt handle and the matmul descriptor.
// NOTE(review): the *_FACTORY_ENTRY / ASSIGNABLE_FACTORY macros are defined
// outside this hunk; the descriptions below are read off the macro names and
// arguments only -- confirm against the macro definitions.

// cublasLtCreate: DEREF(0) is paired with a NEW of
// <dpct namespace>blas_gemm::experimental::descriptor (presumably emitted as
// an assignment of the new object through the first argument).
ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY(
"cublasLtCreate", DEREF(0),
NEW(MapNames::getDpctNamespace() + "blas_gemm::experimental::descriptor")))
// cublasLtDestroy: DELETE rule applied to argument 0.
ASSIGNABLE_FACTORY(DELETE_FACTORY_ENTRY("cublasLtDestroy", ARG(0)))
// cublasLtMatmulDescCreate: DEREF(0) is paired with a NEW of matmul_desc_t
// constructed from arguments 1 and 2.
ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY(
"cublasLtMatmulDescCreate", DEREF(0),
NEW(MapNames::getDpctNamespace() + "blas_gemm::experimental::matmul_desc_t",
ARG(1), ARG(2))))
// cublasLtMatmulDescDestroy: DELETE rule applied to argument 0.
ASSIGNABLE_FACTORY(DELETE_FACTORY_ENTRY("cublasLtMatmulDescDestroy", ARG(0)))
// cublasLtMatmulDescSetAttribute: member call "set_attribute" on argument 0,
// passing arguments 1 and 2 only; any arguments after index 2 are not
// forwarded. The `true` flag presumably selects `->` access -- TODO confirm.
ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cublasLtMatmulDescSetAttribute",
ARG(0), true, "set_attribute",
ARG(1), ARG(2)))
// cublasLtMatmulDescGetAttribute: member call "get_attribute" on argument 0,
// passing arguments 1 and 2 only.
ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cublasLtMatmulDescGetAttribute",
ARG(0), true, "get_attribute",
ARG(1), ARG(2)))

// Migration rules for the cublasLt matrix layout object.
// cublasLtMatrixLayoutCreate: DEREF(0) is paired with a NEW of
// <dpct namespace>blas_gemm::experimental::matrix_layout_t constructed from
// arguments 1 through 4.
ASSIGNABLE_FACTORY(
ASSIGN_FACTORY_ENTRY("cublasLtMatrixLayoutCreate", DEREF(0),
NEW(MapNames::getDpctNamespace() +
"blas_gemm::experimental::matrix_layout_t",
ARG(1), ARG(2), ARG(3), ARG(4))))
// cublasLtMatrixLayoutDestroy: DELETE rule applied to argument 0.
ASSIGNABLE_FACTORY(DELETE_FACTORY_ENTRY("cublasLtMatrixLayoutDestroy", ARG(0)))
// Set/GetAttribute: member calls "set_attribute" / "get_attribute" on
// argument 0, passing arguments 1 and 2 only. The `true` flag presumably
// selects `->` access -- TODO confirm against MEMBER_CALL_FACTORY_ENTRY.
ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cublasLtMatrixLayoutSetAttribute",
ARG(0), true, "set_attribute",
ARG(1), ARG(2)))
ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cublasLtMatrixLayoutGetAttribute",
ARG(0), true, "get_attribute",
ARG(1), ARG(2)))

// cublasLtMatmul: rewritten to a call of
// <dpct namespace>blas_gemm::experimental::matmul that forwards arguments
// 0 through 11 and argument 15. Arguments 12, 13 and 14 are intentionally not
// forwarded.
// NOTE(review): per the public cuBLASLt signature those appear to be algo,
// workspace and workspaceSizeInBytes, with 15 being the stream -- confirm.
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cublasLtMatmul",
CALL(MapNames::getDpctNamespace() + "blas_gemm::experimental::matmul",
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8),
ARG(9), ARG(10), ARG(11), ARG(15))))
// The matmul-preference APIs are registered with REMOVE_API_FACTORY_ENTRY
// rather than mapped to a helper.
REMOVE_API_FACTORY_ENTRY("cublasLtMatmulPreferenceCreate")
REMOVE_API_FACTORY_ENTRY("cublasLtMatmulPreferenceDestroy")
REMOVE_API_FACTORY_ENTRY("cublasLtMatmulPreferenceSetAttribute")
REMOVE_API_FACTORY_ENTRY("cublasLtMatmulPreferenceGetAttribute")
// cublasLtMatmulAlgoGetHeuristic: DEREF(9) is paired with the literal "1";
// no other argument of the original call is used.
// NOTE(review): argument 9 is presumably the returned-algorithm-count
// out-parameter, so migrated code would always see one result -- confirm.
ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cublasLtMatmulAlgoGetHeuristic",
DEREF(9), ARG("1")))

// Migration rules for the cublasLt matrix-transform descriptor and operation.
// cublasLtMatrixTransformDescCreate: DEREF(0) is paired with a NEW of
// <dpct namespace>blas_gemm::experimental::transform_desc_t constructed from
// argument 1.
ASSIGNABLE_FACTORY(
ASSIGN_FACTORY_ENTRY("cublasLtMatrixTransformDescCreate", DEREF(0),
NEW(MapNames::getDpctNamespace() +
"blas_gemm::experimental::transform_desc_t",
ARG(1))))
// cublasLtMatrixTransformDescDestroy: DELETE rule applied to argument 0.
ASSIGNABLE_FACTORY(DELETE_FACTORY_ENTRY("cublasLtMatrixTransformDescDestroy",
ARG(0)))
// Set/GetAttribute: member calls "set_attribute" / "get_attribute" on
// argument 0, passing arguments 1 and 2 only. The `true` flag presumably
// selects `->` access -- TODO confirm against MEMBER_CALL_FACTORY_ENTRY.
ASSIGNABLE_FACTORY(
MEMBER_CALL_FACTORY_ENTRY("cublasLtMatrixTransformDescSetAttribute", ARG(0),
true, "set_attribute", ARG(1), ARG(2)))
ASSIGNABLE_FACTORY(
MEMBER_CALL_FACTORY_ENTRY("cublasLtMatrixTransformDescGetAttribute", ARG(0),
true, "get_attribute", ARG(1), ARG(2)))
// cublasLtMatrixTransform: rewritten to a call of
// <dpct namespace>blas_gemm::experimental::matrix_transform that forwards
// arguments 1 through 10. Argument 0 is not forwarded, unlike the
// cublasLtMatmul rule, which does forward argument 0.
// NOTE(review): argument 0 is presumably the cublasLt handle -- confirm.
ASSIGNABLE_FACTORY(
CALL_FACTORY_ENTRY("cublasLtMatrixTransform",
CALL(MapNames::getDpctNamespace() +
"blas_gemm::experimental::matrix_transform",
ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6),
ARG(7), ARG(8), ARG(9), ARG(10))))
42 changes: 21 additions & 21 deletions clang/lib/DPCT/APINames_cuBLAS.inc
Original file line number Diff line number Diff line change
Expand Up @@ -287,31 +287,31 @@ ENTRY(cublasNrm2Ex, cublasNrm2Ex, true, NO_FLAG, P4, "DPCT1020")
ENTRY(cublasAxpyEx, cublasAxpyEx, true, NO_FLAG, P4, "DPCT1020")
ENTRY(cublasDotEx, cublasDotEx, true, NO_FLAG, P4, "DPCT1020")
ENTRY(cublasScalEx, cublasScalEx, true, NO_FLAG, P4, "DPCT1020")
ENTRY(cublasLtCreate, cublasLtCreate, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtDestroy, cublasLtDestroy, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtCreate, cublasLtCreate, true, NO_FLAG, P4, "DPCT1026/DPCT1027")
ENTRY(cublasLtDestroy, cublasLtDestroy, true, NO_FLAG, P4, "DPCT1026/DPCT1027")
ENTRY(cublasLtGetVersion, cublasLtGetVersion, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtGetCudartVersion, cublasLtGetCudartVersion, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtGetProperty, cublasLtGetProperty, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmul, cublasLtMatmul, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutCreate, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixLayoutGetAttribute, cublasLtMatrixLayoutGetAttribute, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixLayoutSetAttribute, cublasLtMatrixLayoutSetAttribute, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixLayoutDestroy, cublasLtMatrixLayoutDestroy, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixTransform, cublasLtMatrixTransform, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulDescCreate, cublasLtMatmulDescCreate, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulDescDestroy, cublasLtMatmulDescDestroy, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulDescSetAttribute, cublasLtMatmulDescSetAttribute, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulDescGetAttribute, cublasLtMatmulDescGetAttribute, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixTransformDescCreate, cublasLtMatrixTransformDescCreate, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixTransformDescDestroy, cublasLtMatrixTransformDescDestroy, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixTransformDescSetAttribute, cublasLtMatrixTransformDescSetAttribute, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixTransformDescGetAttribute, cublasLtMatrixTransformDescGetAttribute, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulPreferenceCreate, cublasLtMatmulPreferenceCreate, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulPreferenceDestroy, cublasLtMatmulPreferenceDestroy, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulPreferenceSetAttribute, cublasLtMatmulPreferenceSetAttribute, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulPreferenceGetAttribute, cublasLtMatmulPreferenceGetAttribute, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmul, cublasLtMatmul, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutCreate, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixLayoutGetAttribute, cublasLtMatrixLayoutGetAttribute, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixLayoutSetAttribute, cublasLtMatrixLayoutSetAttribute, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixLayoutDestroy, cublasLtMatrixLayoutDestroy, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixTransform, cublasLtMatrixTransform, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulDescCreate, cublasLtMatmulDescCreate, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulDescDestroy, cublasLtMatmulDescDestroy, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulDescSetAttribute, cublasLtMatmulDescSetAttribute, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulDescGetAttribute, cublasLtMatmulDescGetAttribute, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixTransformDescCreate, cublasLtMatrixTransformDescCreate, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixTransformDescDestroy, cublasLtMatrixTransformDescDestroy, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixTransformDescSetAttribute, cublasLtMatrixTransformDescSetAttribute, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatrixTransformDescGetAttribute, cublasLtMatrixTransformDescGetAttribute, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulPreferenceCreate, cublasLtMatmulPreferenceCreate, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulPreferenceDestroy, cublasLtMatmulPreferenceDestroy, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulPreferenceSetAttribute, cublasLtMatmulPreferenceSetAttribute, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulPreferenceGetAttribute, cublasLtMatmulPreferenceGetAttribute, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulAlgoCheck, cublasLtMatmulAlgoCheck, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulAlgoGetHeuristic, cublasLtMatmulAlgoGetHeuristic, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulAlgoGetHeuristic, cublasLtMatmulAlgoGetHeuristic, true, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulAlgoGetIds, cublasLtMatmulAlgoGetIds, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulAlgoInit, cublasLtMatmulAlgoInit, false, NO_FLAG, P4, "comment")
ENTRY(cublasLtMatmulAlgoCapGetAttribute, cublasLtMatmulAlgoCapGetAttribute, false, NO_FLAG, P4, "comment")
Expand Down
32 changes: 28 additions & 4 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1738,7 +1738,14 @@ void TypeInDeclRule::registerMatcher(MatchFinder &MF) {
"cudaLaunchAttributeValue", "cusparseSpSMDescr_t",
"cusparseConstSpMatDescr_t", "cusparseSpSMAlg_t",
"cusparseConstDnMatDescr_t", "cudaMemcpy3DParms", "CUDA_MEMCPY3D",
"CUDA_MEMCPY2D", "CUDA_ARRAY_DESCRIPTOR"))))))
"CUDA_MEMCPY2D", "CUDA_ARRAY_DESCRIPTOR", "cublasLtHandle_t",
"cublasLtMatmulDesc_t", "cublasLtOrder_t",
"cublasLtPointerMode_t", "cublasLtMatrixLayout_t",
"cublasLtMatrixLayoutAttribute_t",
"cublasLtMatmulDescAttributes_t", "cublasLtMatmulAlgo_t",
"cublasLtEpilogue_t", "cublasLtMatmulPreference_t",
"cublasLtMatmulHeuristicResult_t",
"cublasLtMatrixTransformDesc_t"))))))
.bind("cudaTypeDef"),
this);

Expand Down Expand Up @@ -3605,14 +3612,18 @@ REGISTER_RULE(CU_JITEnumsRule, PassKind::PK_Migration)
void BLASEnumsRule::registerMatcher(MatchFinder &MF) {
MF.addMatcher(declRefExpr(to(enumConstantDecl(matchesName(
"(CUBLAS_STATUS.*)|(CUDA_R_.*)|(CUDA_C_.*)|("
"CUBLAS_GEMM_.*)|(CUBLAS_POINTER_MODE.*)"))))
"CUBLAS_GEMM_.*)|(CUBLAS_POINTER_MODE.*)|("
"CUBLASLT_EPILOGUE_.*)"))))
.bind("BLASStatusConstants"),
this);
MF.addMatcher(
declRefExpr(to(enumConstantDecl(matchesName(
"(CUBLAS_OP.*)|(CUBLAS_SIDE.*)|(CUBLAS_FILL_"
"MODE.*)|(CUBLAS_DIAG.*)|(CUBLAS_.*_MATH)|CUBLAS_MATH_"
"DISALLOW_REDUCED_PRECISION_REDUCTION"))))
"DISALLOW_REDUCED_PRECISION_REDUCTION|(CUBLASLT_ORDER_.*)"
"|(CUBLASLT_POINTER_MODE_.*)|(CUBLASLT_MATRIX_LAYOUT_.*)|"
"(CUBLASLT_MATMUL_DESC_.*)|(CUBLASLT_MATRIX_TRANSFORM_"
"DESC_.*)"))))
.bind("BLASNamedValueConstants"),
this);
}
Expand Down Expand Up @@ -4349,7 +4360,20 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) {
"cublasDsyr2k_v2_64", "cublasCsyr2k_v2_64", "cublasZsyr2k_v2_64",
"cublasCher2k_v2_64", "cublasZher2k_v2_64", "cublasSgeam_64",
"cublasDgeam_64", "cublasCgeam_64", "cublasZgeam_64", "cublasSdgmm_64",
"cublasDdgmm_64", "cublasCdgmm_64", "cublasZdgmm_64");
"cublasDdgmm_64", "cublasCdgmm_64", "cublasZdgmm_64",
/*cublasLt*/
"cublasLtCreate", "cublasLtDestroy", "cublasLtMatmulDescCreate",
"cublasLtMatmulDescDestroy", "cublasLtMatmulDescSetAttribute",
"cublasLtMatmulDescGetAttribute", "cublasLtMatrixLayoutCreate",
"cublasLtMatrixLayoutDestroy", "cublasLtMatrixLayoutGetAttribute",
"cublasLtMatrixLayoutSetAttribute", "cublasLtMatmul",
"cublasLtMatmulPreferenceCreate", "cublasLtMatmulPreferenceDestroy",
"cublasLtMatmulPreferenceSetAttribute",
"cublasLtMatmulPreferenceGetAttribute",
"cublasLtMatmulAlgoGetHeuristic", "cublasLtMatrixTransformDescCreate",
"cublasLtMatrixTransformDescDestroy",
"cublasLtMatrixTransformDescSetAttribute",
"cublasLtMatrixTransformDescGetAttribute", "cublasLtMatrixTransform");
};

MF.addMatcher(callExpr(allOf(callee(functionDecl(functionName())),
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/DPCT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ set(RUNTIME_HEADERS
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/fft_utils.hpp
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/lapack_utils.hpp
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/group_utils.hpp
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dpl_extras/algorithm.h
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dpl_extras/functional.h
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dpl_extras/iterators.h
Expand Down Expand Up @@ -62,6 +63,7 @@ set(PROCESS_FILES_OUTPUT
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/fft_utils.hpp.inc
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/lapack_utils.hpp.inc
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/group_utils.hpp.inc
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/blas_gemm_utils.hpp.inc
)

add_custom_command(
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/DPCT/GenHelperFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ const std::string LapackUtilsAllContentStr =
const std::string GroupUtilsAllContentStr =
#include "clang/DPCT/group_utils.hpp.inc"
;
const std::string BlasGemmUtilsAllContentStr =
#include "clang/DPCT/blas_gemm_utils.hpp.inc"
;
const std::string DplExtrasAlgorithmAllContentStr =
#include "clang/DPCT/dpl_extras/algorithm.h.inc"
;
Expand Down Expand Up @@ -170,6 +173,7 @@ void genHelperFunction(const clang::tooling::UnifiedPath &OutRoot) {
GENERATE_ALL_FILE_CONTENT(FftUtils, ".", fft_utils.hpp)
GENERATE_ALL_FILE_CONTENT(LapackUtils, ".", lapack_utils.hpp)
GENERATE_ALL_FILE_CONTENT(GroupUtils, ".", group_utils.hpp)
GENERATE_ALL_FILE_CONTENT(BlasGemmUtils, ".", blas_gemm_utils.hpp)
GENERATE_ALL_FILE_CONTENT(CodePin, "codepin", codepin.hpp)
GENERATE_ALL_FILE_CONTENT(CodePinSerializationBasic, "codepin/serialization",
basic.hpp)
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/HeaderTypes.inc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ DPCT_HEADER(COMMON_Utils, "<dpct/lib_common_utils.hpp>")
DPCT_HEADER(Atomic, "<dpct/atomic.hpp>")
DPCT_HEADER(SPBLAS_Utils, "<dpct/sparse_utils.hpp>")
DPCT_HEADER(Math, "<dpct/math.hpp>")
DPCT_HEADER(BLAS_GEMM_Utils, "<dpct/blas_gemm_utils.hpp>")
DPCT_HEADER(CodePin_SYCL, "<dpct/codepin/codepin.hpp>")
DPCT_HEADER(CodePin_CUDA, "<dpct/codepin/codepin.hpp>")
DPCT_HEADER(Graph, "<dpct/graph.hpp>")
Expand Down
6 changes: 4 additions & 2 deletions clang/lib/DPCT/InclusionHeaders.inc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ REGIST_INCLUSION("curand_kernel.h", FullMatch, Rng, Replace, false,
HeaderType::HT_DPCT_RNG_Utils)

REGIST_INCLUSION("cusparse.h", FullMatch, Sparse, Replace, false,
HeaderType::HT_DPCT_SPBLAS_Utils,
HeaderType::HT_DPCT_BLAS_Utils)
HeaderType::HT_DPCT_SPBLAS_Utils)
REGIST_INCLUSION("cusparse_v2.h", FullMatch, Sparse, Replace, false,
HeaderType::HT_DPCT_SPBLAS_Utils,
HeaderType::HT_DPCT_BLAS_Utils)
Expand Down Expand Up @@ -106,3 +105,6 @@ REGIST_INCLUSION("cub/", Startwith, Thrust, Replace, false,
REGIST_INCLUSION("CL/", Startwith, Common, DoNothing, false)

REGIST_INCLUSION("cuda_rutime.h", FullMatch, Libcu, Remove, true)

// Exact-name (FullMatch) rule for "cublasLt.h": the include is handled with
// the Replace action and associated with the BLAS_GEMM_Utils helper header.
// NOTE(review): presumably this swaps the include for
// <dpct/blas_gemm_utils.hpp>; REGIST_INCLUSION is defined elsewhere -- confirm.
REGIST_INCLUSION("cublasLt.h", FullMatch, BLas, Replace, false,
HeaderType::HT_DPCT_BLAS_GEMM_Utils)
Loading

0 comments on commit dcd88af

Please sign in to comment.