diff --git a/clang/lib/DPCT/APINamesCUBLAS.inc b/clang/lib/DPCT/APINamesCUBLAS.inc index b0d8b8d4bcb7..eb3eadb04e50 100644 --- a/clang/lib/DPCT/APINamesCUBLAS.inc +++ b/clang/lib/DPCT/APINamesCUBLAS.inc @@ -1634,3 +1634,64 @@ GEMM_BATCH(cublasDgemmStridedBatched, "double", false) GEMM_BATCH(cublasCgemmStridedBatched, "std::complex", true) GEMM_BATCH(cublasZgemmStridedBatched, "std::complex", true) #undef GEMM_BATCH + +ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY( + "cublasLtCreate", DEREF(0), + NEW(MapNames::getDpctNamespace() + "blas_gemm::experimental::descriptor"))) +ASSIGNABLE_FACTORY(DELETE_FACTORY_ENTRY("cublasLtDestroy", ARG(0))) +ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY( + "cublasLtMatmulDescCreate", DEREF(0), + NEW(MapNames::getDpctNamespace() + "blas_gemm::experimental::matmul_desc_t", + ARG(1), ARG(2)))) +ASSIGNABLE_FACTORY(DELETE_FACTORY_ENTRY("cublasLtMatmulDescDestroy", ARG(0))) +ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cublasLtMatmulDescSetAttribute", + ARG(0), true, "set_attribute", + ARG(1), ARG(2))) +ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cublasLtMatmulDescGetAttribute", + ARG(0), true, "get_attribute", + ARG(1), ARG(2))) + +ASSIGNABLE_FACTORY( + ASSIGN_FACTORY_ENTRY("cublasLtMatrixLayoutCreate", DEREF(0), + NEW(MapNames::getDpctNamespace() + + "blas_gemm::experimental::matrix_layout_t", + ARG(1), ARG(2), ARG(3), ARG(4)))) +ASSIGNABLE_FACTORY(DELETE_FACTORY_ENTRY("cublasLtMatrixLayoutDestroy", ARG(0))) +ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cublasLtMatrixLayoutSetAttribute", + ARG(0), true, "set_attribute", + ARG(1), ARG(2))) +ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cublasLtMatrixLayoutGetAttribute", + ARG(0), true, "get_attribute", + ARG(1), ARG(2))) + +ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( + "cublasLtMatmul", + CALL(MapNames::getDpctNamespace() + "blas_gemm::experimental::matmul", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), + ARG(9), ARG(10), ARG(11), ARG(15)))) +REMOVE_API_FACTORY_ENTRY("cublasLtMatmulPreferenceCreate") +REMOVE_API_FACTORY_ENTRY("cublasLtMatmulPreferenceDestroy") +REMOVE_API_FACTORY_ENTRY("cublasLtMatmulPreferenceSetAttribute") +REMOVE_API_FACTORY_ENTRY("cublasLtMatmulPreferenceGetAttribute") +ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cublasLtMatmulAlgoGetHeuristic", + DEREF(9), ARG("1"))) + +ASSIGNABLE_FACTORY( + ASSIGN_FACTORY_ENTRY("cublasLtMatrixTransformDescCreate", DEREF(0), + NEW(MapNames::getDpctNamespace() + + "blas_gemm::experimental::transform_desc_t", + ARG(1)))) +ASSIGNABLE_FACTORY(DELETE_FACTORY_ENTRY("cublasLtMatrixTransformDescDestroy", + ARG(0))) +ASSIGNABLE_FACTORY( + MEMBER_CALL_FACTORY_ENTRY("cublasLtMatrixTransformDescSetAttribute", ARG(0), + true, "set_attribute", ARG(1), ARG(2))) +ASSIGNABLE_FACTORY( + MEMBER_CALL_FACTORY_ENTRY("cublasLtMatrixTransformDescGetAttribute", ARG(0), + true, "get_attribute", ARG(1), ARG(2))) +ASSIGNABLE_FACTORY( + CALL_FACTORY_ENTRY("cublasLtMatrixTransform", + CALL(MapNames::getDpctNamespace() + + "blas_gemm::experimental::matrix_transform", + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), + ARG(7), ARG(8), ARG(9), ARG(10)))) diff --git a/clang/lib/DPCT/APINames_cuBLAS.inc b/clang/lib/DPCT/APINames_cuBLAS.inc index 6064792ad9d1..9fed09697ba4 100644 --- a/clang/lib/DPCT/APINames_cuBLAS.inc +++ b/clang/lib/DPCT/APINames_cuBLAS.inc @@ -287,31 +287,31 @@ ENTRY(cublasNrm2Ex, cublasNrm2Ex, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasAxpyEx, cublasAxpyEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasDotEx, cublasDotEx, true, NO_FLAG, P4, "DPCT1020") 
ENTRY(cublasScalEx, cublasScalEx, true, NO_FLAG, P4, "DPCT1020") -ENTRY(cublasLtCreate, cublasLtCreate, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtDestroy, cublasLtDestroy, false, NO_FLAG, P4, "comment") +ENTRY(cublasLtCreate, cublasLtCreate, true, NO_FLAG, P4, "DPCT1026/DPCT1027") +ENTRY(cublasLtDestroy, cublasLtDestroy, true, NO_FLAG, P4, "DPCT1026/DPCT1027") ENTRY(cublasLtGetVersion, cublasLtGetVersion, false, NO_FLAG, P4, "comment") ENTRY(cublasLtGetCudartVersion, cublasLtGetCudartVersion, false, NO_FLAG, P4, "comment") ENTRY(cublasLtGetProperty, cublasLtGetProperty, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatmul, cublasLtMatmul, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutCreate, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatrixLayoutGetAttribute, cublasLtMatrixLayoutGetAttribute, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatrixLayoutSetAttribute, cublasLtMatrixLayoutSetAttribute, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatrixLayoutDestroy, cublasLtMatrixLayoutDestroy, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatrixTransform, cublasLtMatrixTransform, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatmulDescCreate, cublasLtMatmulDescCreate, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatmulDescDestroy, cublasLtMatmulDescDestroy, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatmulDescSetAttribute, cublasLtMatmulDescSetAttribute, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatmulDescGetAttribute, cublasLtMatmulDescGetAttribute, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatrixTransformDescCreate, cublasLtMatrixTransformDescCreate, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatrixTransformDescDestroy, cublasLtMatrixTransformDescDestroy, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatrixTransformDescSetAttribute, cublasLtMatrixTransformDescSetAttribute, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatrixTransformDescGetAttribute, cublasLtMatrixTransformDescGetAttribute, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatmulPreferenceCreate, cublasLtMatmulPreferenceCreate, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatmulPreferenceDestroy, cublasLtMatmulPreferenceDestroy, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatmulPreferenceSetAttribute, cublasLtMatmulPreferenceSetAttribute, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatmulPreferenceGetAttribute, cublasLtMatmulPreferenceGetAttribute, false, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatmul, cublasLtMatmul, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatrixLayoutCreate, cublasLtMatrixLayoutCreate, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatrixLayoutGetAttribute, cublasLtMatrixLayoutGetAttribute, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatrixLayoutSetAttribute, cublasLtMatrixLayoutSetAttribute, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatrixLayoutDestroy, cublasLtMatrixLayoutDestroy, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatrixTransform, cublasLtMatrixTransform, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatmulDescCreate, cublasLtMatmulDescCreate, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatmulDescDestroy, cublasLtMatmulDescDestroy, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatmulDescSetAttribute, cublasLtMatmulDescSetAttribute, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatmulDescGetAttribute, cublasLtMatmulDescGetAttribute, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatrixTransformDescCreate, cublasLtMatrixTransformDescCreate, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatrixTransformDescDestroy, 
cublasLtMatrixTransformDescDestroy, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatrixTransformDescSetAttribute, cublasLtMatrixTransformDescSetAttribute, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatrixTransformDescGetAttribute, cublasLtMatrixTransformDescGetAttribute, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatmulPreferenceCreate, cublasLtMatmulPreferenceCreate, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatmulPreferenceDestroy, cublasLtMatmulPreferenceDestroy, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatmulPreferenceSetAttribute, cublasLtMatmulPreferenceSetAttribute, true, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatmulPreferenceGetAttribute, cublasLtMatmulPreferenceGetAttribute, true, NO_FLAG, P4, "comment") ENTRY(cublasLtMatmulAlgoCheck, cublasLtMatmulAlgoCheck, false, NO_FLAG, P4, "comment") -ENTRY(cublasLtMatmulAlgoGetHeuristic, cublasLtMatmulAlgoGetHeuristic, false, NO_FLAG, P4, "comment") +ENTRY(cublasLtMatmulAlgoGetHeuristic, cublasLtMatmulAlgoGetHeuristic, true, NO_FLAG, P4, "comment") ENTRY(cublasLtMatmulAlgoGetIds, cublasLtMatmulAlgoGetIds, false, NO_FLAG, P4, "comment") ENTRY(cublasLtMatmulAlgoInit, cublasLtMatmulAlgoInit, false, NO_FLAG, P4, "comment") ENTRY(cublasLtMatmulAlgoCapGetAttribute, cublasLtMatmulAlgoCapGetAttribute, false, NO_FLAG, P4, "comment") diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index 4dc3658e6228..1e1915b7915a 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -1738,7 +1738,14 @@ void TypeInDeclRule::registerMatcher(MatchFinder &MF) { "cudaLaunchAttributeValue", "cusparseSpSMDescr_t", "cusparseConstSpMatDescr_t", "cusparseSpSMAlg_t", "cusparseConstDnMatDescr_t", "cudaMemcpy3DParms", "CUDA_MEMCPY3D", - "CUDA_MEMCPY2D", "CUDA_ARRAY_DESCRIPTOR")))))) + "CUDA_MEMCPY2D", "CUDA_ARRAY_DESCRIPTOR", "cublasLtHandle_t", + "cublasLtMatmulDesc_t", "cublasLtOrder_t", + "cublasLtPointerMode_t", "cublasLtMatrixLayout_t", + "cublasLtMatrixLayoutAttribute_t", + "cublasLtMatmulDescAttributes_t", "cublasLtMatmulAlgo_t", + "cublasLtEpilogue_t", "cublasLtMatmulPreference_t", + "cublasLtMatmulHeuristicResult_t", + "cublasLtMatrixTransformDesc_t")))))) .bind("cudaTypeDef"), this); @@ -3605,14 +3612,18 @@ REGISTER_RULE(CU_JITEnumsRule, PassKind::PK_Migration) void BLASEnumsRule::registerMatcher(MatchFinder &MF) { MF.addMatcher(declRefExpr(to(enumConstantDecl(matchesName( "(CUBLAS_STATUS.*)|(CUDA_R_.*)|(CUDA_C_.*)|(" - "CUBLAS_GEMM_.*)|(CUBLAS_POINTER_MODE.*)")))) + "CUBLAS_GEMM_.*)|(CUBLAS_POINTER_MODE.*)|(" + "CUBLASLT_EPILOGUE_.*)")))) .bind("BLASStatusConstants"), this); MF.addMatcher( declRefExpr(to(enumConstantDecl(matchesName( "(CUBLAS_OP.*)|(CUBLAS_SIDE.*)|(CUBLAS_FILL_" "MODE.*)|(CUBLAS_DIAG.*)|(CUBLAS_.*_MATH)|CUBLAS_MATH_" - "DISALLOW_REDUCED_PRECISION_REDUCTION")))) + "DISALLOW_REDUCED_PRECISION_REDUCTION|(CUBLASLT_ORDER_.*)" + "|(CUBLASLT_POINTER_MODE_.*)|(CUBLASLT_MATRIX_LAYOUT_.*)|" + "(CUBLASLT_MATMUL_DESC_.*)|(CUBLASLT_MATRIX_TRANSFORM_" + "DESC_.*)")))) .bind("BLASNamedValueConstants"), this); } @@ -4349,7 +4360,20 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasDsyr2k_v2_64", "cublasCsyr2k_v2_64", "cublasZsyr2k_v2_64", "cublasCher2k_v2_64", "cublasZher2k_v2_64", "cublasSgeam_64", "cublasDgeam_64", "cublasCgeam_64", "cublasZgeam_64", "cublasSdgmm_64", - "cublasDdgmm_64", "cublasCdgmm_64", "cublasZdgmm_64"); + "cublasDdgmm_64", "cublasCdgmm_64", "cublasZdgmm_64", + /*cublasLt*/ + "cublasLtCreate", "cublasLtDestroy", "cublasLtMatmulDescCreate", + 
"cublasLtMatmulDescDestroy", "cublasLtMatmulDescSetAttribute", + "cublasLtMatmulDescGetAttribute", "cublasLtMatrixLayoutCreate", + "cublasLtMatrixLayoutDestroy", "cublasLtMatrixLayoutGetAttribute", + "cublasLtMatrixLayoutSetAttribute", "cublasLtMatmul", + "cublasLtMatmulPreferenceCreate", "cublasLtMatmulPreferenceDestroy", + "cublasLtMatmulPreferenceSetAttribute", + "cublasLtMatmulPreferenceGetAttribute", + "cublasLtMatmulAlgoGetHeuristic", "cublasLtMatrixTransformDescCreate", + "cublasLtMatrixTransformDescDestroy", + "cublasLtMatrixTransformDescSetAttribute", + "cublasLtMatrixTransformDescGetAttribute", "cublasLtMatrixTransform"); }; MF.addMatcher(callExpr(allOf(callee(functionDecl(functionName())), diff --git a/clang/lib/DPCT/CMakeLists.txt b/clang/lib/DPCT/CMakeLists.txt index 4f37303bea06..f916fc97e41a 100644 --- a/clang/lib/DPCT/CMakeLists.txt +++ b/clang/lib/DPCT/CMakeLists.txt @@ -24,6 +24,7 @@ set(RUNTIME_HEADERS ${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/fft_utils.hpp ${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/lapack_utils.hpp ${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/group_utils.hpp + ${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp ${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dpl_extras/algorithm.h ${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dpl_extras/functional.h ${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dpl_extras/iterators.h @@ -62,6 +63,7 @@ set(PROCESS_FILES_OUTPUT ${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/fft_utils.hpp.inc ${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/lapack_utils.hpp.inc ${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/group_utils.hpp.inc + ${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/blas_gemm_utils.hpp.inc ) add_custom_command( diff --git a/clang/lib/DPCT/GenHelperFunction.cpp b/clang/lib/DPCT/GenHelperFunction.cpp index d5b570b3fa8e..96c8e9ab6ed3 100644 --- a/clang/lib/DPCT/GenHelperFunction.cpp +++ b/clang/lib/DPCT/GenHelperFunction.cpp @@ -79,6 +79,9 @@ const std::string LapackUtilsAllContentStr = const std::string GroupUtilsAllContentStr = #include "clang/DPCT/group_utils.hpp.inc" ; +const std::string BlasGemmUtilsAllContentStr = +#include "clang/DPCT/blas_gemm_utils.hpp.inc" + ; const std::string DplExtrasAlgorithmAllContentStr = #include "clang/DPCT/dpl_extras/algorithm.h.inc" ; @@ -170,6 +173,7 @@ void genHelperFunction(const clang::tooling::UnifiedPath &OutRoot) { GENERATE_ALL_FILE_CONTENT(FftUtils, ".", fft_utils.hpp) GENERATE_ALL_FILE_CONTENT(LapackUtils, ".", lapack_utils.hpp) GENERATE_ALL_FILE_CONTENT(GroupUtils, ".", group_utils.hpp) + GENERATE_ALL_FILE_CONTENT(BlasGemmUtils, ".", blas_gemm_utils.hpp) GENERATE_ALL_FILE_CONTENT(CodePin, "codepin", codepin.hpp) GENERATE_ALL_FILE_CONTENT(CodePinSerializationBasic, "codepin/serialization", basic.hpp) diff --git a/clang/lib/DPCT/HeaderTypes.inc b/clang/lib/DPCT/HeaderTypes.inc index 3d138288eecb..de9af20a6c41 100644 --- a/clang/lib/DPCT/HeaderTypes.inc +++ b/clang/lib/DPCT/HeaderTypes.inc @@ -70,6 +70,7 @@ DPCT_HEADER(COMMON_Utils, "") DPCT_HEADER(Atomic, "") DPCT_HEADER(SPBLAS_Utils, "") DPCT_HEADER(Math, "") +DPCT_HEADER(BLAS_GEMM_Utils, "") DPCT_HEADER(CodePin_SYCL, "") DPCT_HEADER(CodePin_CUDA, "") DPCT_HEADER(Graph, "") diff --git a/clang/lib/DPCT/InclusionHeaders.inc b/clang/lib/DPCT/InclusionHeaders.inc index 26caa59b1ec3..87991dc3ae14 100644 --- a/clang/lib/DPCT/InclusionHeaders.inc +++ b/clang/lib/DPCT/InclusionHeaders.inc @@ -48,8 +48,7 @@ 
REGIST_INCLUSION("curand_kernel.h", FullMatch, Rng, Replace, false, HeaderType::HT_DPCT_RNG_Utils) REGIST_INCLUSION("cusparse.h", FullMatch, Sparse, Replace, false, - HeaderType::HT_DPCT_SPBLAS_Utils, - HeaderType::HT_DPCT_BLAS_Utils) + HeaderType::HT_DPCT_SPBLAS_Utils) REGIST_INCLUSION("cusparse_v2.h", FullMatch, Sparse, Replace, false, HeaderType::HT_DPCT_SPBLAS_Utils, HeaderType::HT_DPCT_BLAS_Utils) @@ -106,3 +105,6 @@ REGIST_INCLUSION("cub/", Startwith, Thrust, Replace, false, REGIST_INCLUSION("CL/", Startwith, Common, DoNothing, false) REGIST_INCLUSION("cuda_rutime.h", FullMatch, Libcu, Remove, true) + +REGIST_INCLUSION("cublasLt.h", FullMatch, BLas, Replace, false, + HeaderType::HT_DPCT_BLAS_GEMM_Utils) diff --git a/clang/lib/DPCT/MapNames.cpp b/clang/lib/DPCT/MapNames.cpp index d4a1abc8b14c..5bc8b25945af 100644 --- a/clang/lib/DPCT/MapNames.cpp +++ b/clang/lib/DPCT/MapNames.cpp @@ -200,8 +200,8 @@ void MapNames::setExplicitNamespaceMap() { {"cublasDataType_t", std::make_shared(getDpctNamespace() + "library_data_t", HelperFeatureEnum::device_ext)}, - {"cublasComputeType_t", std::make_shared( - getDpctNamespace() + "blas::compute_type")}, + {"cublasComputeType_t", + std::make_shared(getDpctNamespace() + "compute_type")}, {"cuComplex", std::make_shared(getClNamespace() + "float2")}, {"cuFloatComplex", @@ -519,6 +519,37 @@ void MapNames::setExplicitNamespaceMap() { {"cudaLaunchAttributeValue", std::make_shared("int")}, {"cusparseSpSMDescr_t", std::make_shared("int")}, {"cusparseSpSMAlg_t", std::make_shared("int")}, + {"cublasLtHandle_t", + std::make_shared( + getDpctNamespace() + "blas_gemm::experimental::descriptor_ptr")}, + {"cublasLtMatmulDesc_t", + std::make_shared( + getDpctNamespace() + "blas_gemm::experimental::matmul_desc_ptr")}, + {"cublasLtOrder_t", + std::make_shared(getDpctNamespace() + + "blas_gemm::experimental::order_t")}, + {"cublasLtPointerMode_t", + std::make_shared( + getDpctNamespace() + "blas_gemm::experimental::pointer_mode_t")}, + {"cublasLtMatrixLayout_t", + std::make_shared( + getDpctNamespace() + "blas_gemm::experimental::matrix_layout_ptr")}, + {"cublasLtMatrixLayoutAttribute_t", + std::make_shared( + getDpctNamespace() + + "blas_gemm::experimental::matrix_layout_t::attribute")}, + {"cublasLtMatmulDescAttributes_t", + std::make_shared( + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t::attribute")}, + {"cublasLtMatmulAlgo_t", std::make_shared("int")}, + {"cublasLtEpilogue_t", std::make_shared("int")}, + {"cublasLtMatmulPreference_t", std::make_shared("int")}, + {"cublasLtMatmulHeuristicResult_t", + std::make_shared("int")}, + {"cublasLtMatrixTransformDesc_t", + std::make_shared( + getDpctNamespace() + "blas_gemm::experimental::transform_desc_ptr")}, // ... 
}; @@ -1239,39 +1270,35 @@ void MapNames::setExplicitNamespaceMap() { std::make_shared(getDpctNamespace() + "library_data_t::real_f8_e5m2")}, // cublasComputeType_t - {"CUBLAS_COMPUTE_16F", - std::make_shared(getDpctNamespace() + - "blas::compute_type::f16")}, + {"CUBLAS_COMPUTE_16F", std::make_shared( + getDpctNamespace() + "compute_type::f16")}, {"CUBLAS_COMPUTE_16F_PEDANTIC", std::make_shared(getDpctNamespace() + - "blas::compute_type::f16_standard")}, - {"CUBLAS_COMPUTE_32F", - std::make_shared(getDpctNamespace() + - "blas::compute_type::f32")}, + "compute_type::f16_standard")}, + {"CUBLAS_COMPUTE_32F", std::make_shared( + getDpctNamespace() + "compute_type::f32")}, {"CUBLAS_COMPUTE_32F_PEDANTIC", std::make_shared(getDpctNamespace() + - "blas::compute_type::f32_standard")}, + "compute_type::f32_standard")}, {"CUBLAS_COMPUTE_32F_FAST_16F", std::make_shared(getDpctNamespace() + - "blas::compute_type::f32")}, + "compute_type::f32")}, {"CUBLAS_COMPUTE_32F_FAST_16BF", std::make_shared(getDpctNamespace() + - "blas::compute_type::f32_fast_bf16")}, + "compute_type::f32_fast_bf16")}, {"CUBLAS_COMPUTE_32F_FAST_TF32", std::make_shared(getDpctNamespace() + - "blas::compute_type::f32_fast_tf32")}, - {"CUBLAS_COMPUTE_64F", - std::make_shared(getDpctNamespace() + - "blas::compute_type::f64")}, + "compute_type::f32_fast_tf32")}, + {"CUBLAS_COMPUTE_64F", std::make_shared( + getDpctNamespace() + "compute_type::f64")}, {"CUBLAS_COMPUTE_64F_PEDANTIC", std::make_shared(getDpctNamespace() + - "blas::compute_type::f64_standard")}, - {"CUBLAS_COMPUTE_32I", - std::make_shared(getDpctNamespace() + - "blas::compute_type::i32")}, + "compute_type::f64_standard")}, + {"CUBLAS_COMPUTE_32I", std::make_shared( + getDpctNamespace() + "compute_type::i32")}, {"CUBLAS_COMPUTE_32I_PEDANTIC", std::make_shared(getDpctNamespace() + - "blas::compute_type::i32_standard")}, + "compute_type::i32_standard")}, {"cuda::thread_scope_system", std::make_shared(getClNamespace() + "memory_scope::system")}, @@ -1425,6 +1452,77 @@ void MapNames::setExplicitNamespaceMap() { getDpctNamespace() + "blas::math_mode::mm_tf32"}, {"CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION", getDpctNamespace() + "blas::math_mode::mm_default"}, + {"CUBLASLT_ORDER_COL", + getDpctNamespace() + "blas_gemm::experimental::order_t::col"}, + {"CUBLASLT_ORDER_ROW", + getDpctNamespace() + "blas_gemm::experimental::order_t::row"}, + {"CUBLASLT_ORDER_COL32", + getDpctNamespace() + "blas_gemm::experimental::order_t::col32"}, + {"CUBLASLT_ORDER_COL4_4R2_8C", + getDpctNamespace() + "blas_gemm::experimental::order_t::col4_4r2_8c"}, + {"CUBLASLT_ORDER_COL32_2R_4R4", + getDpctNamespace() + "blas_gemm::experimental::order_t::col32_2r_4r4"}, + {"CUBLASLT_POINTER_MODE_HOST", + getDpctNamespace() + "blas_gemm::experimental::pointer_mode_t::host"}, + {"CUBLASLT_POINTER_MODE_DEVICE", + getDpctNamespace() + "blas_gemm::experimental::pointer_mode_t::device"}, + {"CUBLASLT_POINTER_MODE_DEVICE_VECTOR", + getDpctNamespace() + + "blas_gemm::experimental::pointer_mode_t::device_vector"}, + {"CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO", + getDpctNamespace() + "blas_gemm::experimental::pointer_mode_t::alpha_" + "device_vector_beta_zero"}, + {"CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST", + getDpctNamespace() + "blas_gemm::experimental::pointer_mode_t::alpha_" + "device_vector_beta_host"}, + {"CUBLASLT_MATRIX_LAYOUT_TYPE", + getDpctNamespace() + + "blas_gemm::experimental::matrix_layout_t::attribute::type"}, + {"CUBLASLT_MATRIX_LAYOUT_ORDER", + getDpctNamespace() + + 
"blas_gemm::experimental::matrix_layout_t::attribute::order"}, + {"CUBLASLT_MATRIX_LAYOUT_ROWS", + getDpctNamespace() + + "blas_gemm::experimental::matrix_layout_t::attribute::rows"}, + {"CUBLASLT_MATRIX_LAYOUT_COLS", + getDpctNamespace() + + "blas_gemm::experimental::matrix_layout_t::attribute::cols"}, + {"CUBLASLT_MATRIX_LAYOUT_LD", + getDpctNamespace() + + "blas_gemm::experimental::matrix_layout_t::attribute::ld"}, + {"CUBLASLT_MATMUL_DESC_COMPUTE_TYPE", + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t::attribute::compute_type"}, + {"CUBLASLT_MATMUL_DESC_SCALE_TYPE", + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t::attribute::scale_type"}, + {"CUBLASLT_MATMUL_DESC_POINTER_MODE", + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t::attribute::pointer_mode"}, + {"CUBLASLT_MATMUL_DESC_TRANSA", + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t::attribute::trans_a"}, + {"CUBLASLT_MATMUL_DESC_TRANSB", + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t::attribute::trans_b"}, + {"CUBLASLT_MATMUL_DESC_TRANSC", + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t::attribute::trans_c"}, + {"CUBLASLT_MATMUL_DESC_EPILOGUE", + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t::attribute::epilogue"}, + {"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE", + getDpctNamespace() + + "blas_gemm::experimental::transform_desc_t::attribute::scale_type"}, + {"CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE", + getDpctNamespace() + "blas_gemm::experimental::transform_desc_t::" + "attribute::pointer_mode"}, + {"CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA", + getDpctNamespace() + + "blas_gemm::experimental::transform_desc_t::attribute::trans_a"}, + {"CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSB", + getDpctNamespace() + + "blas_gemm::experimental::transform_desc_t::attribute::trans_b"}, }; ClassFieldMap = {}; @@ -1895,6 +1993,55 @@ void MapNames::setExplicitNamespaceMap() { {"cublasDtpsv_v2_64", "oneapi::mkl::blas::column_major::tpsv"}, {"cublasCtpsv_v2_64", "oneapi::mkl::blas::column_major::tpsv"}, {"cublasZtpsv_v2_64", "oneapi::mkl::blas::column_major::tpsv"}, + // cublasLt + {"cublasLtCreate", + "new " + getDpctNamespace() + "blas_gemm::experimental::descriptor"}, + {"cublasLtDestroy", + "delete " + getDpctNamespace() + "blas_gemm::experimental::descriptor"}, + {"cublasLtMatmulDescCreate", + "new " + getDpctNamespace() + "blas_gemm::experimental::matmul_desc_t"}, + {"cublasLtMatmulDescDestroy", + "delete " + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t"}, + {"cublasLtMatmulDescSetAttribute", + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t::set_attribute"}, + {"cublasLtMatmulDescGetAttribute", + getDpctNamespace() + + "blas_gemm::experimental::matmul_desc_t::get_attribute"}, + {"cublasLtMatrixLayoutCreate", + "new " + getDpctNamespace() + + "blas_gemm::experimental::matrix_layout_t"}, + {"cublasLtMatrixLayoutDestroy", + "delete " + getDpctNamespace() + + "blas_gemm::experimental::matrix_layout_t"}, + {"cublasLtMatrixLayoutSetAttribute", + getDpctNamespace() + + "blas_gemm::experimental::matrix_layout_t::set_attribute"}, + {"cublasLtMatrixLayoutGetAttribute", + getDpctNamespace() + + "blas_gemm::experimental::matrix_layout_t::get_attribute"}, + {"cublasLtMatmul", + getDpctNamespace() + "blas_gemm::experimental::matmul"}, + {"cublasLtMatmulPreferenceCreate", ""}, + {"cublasLtMatmulPreferenceDestroy", ""}, + {"cublasLtMatmulPreferenceSetAttribute", ""}, + {"cublasLtMatmulPreferenceGetAttribute", ""}, + 
{"cublasLtMatmulAlgoGetHeuristic", ""}, + {"cublasLtMatrixTransformDescCreate", + "new " + getDpctNamespace() + + "blas_gemm::experimental::transform_desc_t"}, + {"cublasLtMatrixTransformDescDestroy", + "delete" + getDpctNamespace() + + "blas_gemm::experimental::transform_desc_t"}, + {"cublasLtMatrixTransformDescSetAttribute", + getDpctNamespace() + + "blas_gemm::experimental::transform_desc_t::set_attribute"}, + {"cublasLtMatrixTransformDescGetAttribute", + getDpctNamespace() + + "blas_gemm::experimental::transform_desc_t::get_attribute"}, + {"cublasLtMatrixTransform", + getDpctNamespace() + "blas_gemm::experimental::matrix_transform"}, }; SOLVERAPIWithRewriter = {"cusolverDnSetAdvOptions", diff --git a/clang/runtime/dpct-rt/CMakeLists.txt b/clang/runtime/dpct-rt/CMakeLists.txt index 4b57a4541fc5..ed0cce9c14ee 100644 --- a/clang/runtime/dpct-rt/CMakeLists.txt +++ b/clang/runtime/dpct-rt/CMakeLists.txt @@ -18,6 +18,7 @@ set(dpct_rt_files include/dpct/fft_utils.hpp include/dpct/lapack_utils.hpp include/dpct/group_utils.hpp + include/dpct/blas_gemm_utils.hpp include/dpct/graph.hpp ) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp new file mode 100644 index 000000000000..c56801936c52 --- /dev/null +++ b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp @@ -0,0 +1,847 @@ +//==---- blas_gemm_utils.hpp ----------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// +// This file contains the implementation of GEneral Matrix Multiplication by +// using oneDNN. More datatype combinations and the epilogue can be supported by +// oneDNN, which is not available in blas_utils.hpp using oneMKL. 
+//===----------------------------------------------------------------------===// + +#ifndef __DPCT_BLAS_GEMM_UTILS_HPP__ +#define __DPCT_BLAS_GEMM_UTILS_HPP__ + +#include "dnnl_utils.hpp" +#include "lib_common_utils.hpp" +#include "memory.hpp" + +#include +#include + +#include + +namespace dpct { +namespace blas_gemm { +namespace experimental { +enum class order_t : std::uint8_t { + col, + row, + col32, + col4_4r2_8c, + col32_2r_4r4 +}; +enum class pointer_mode_t { + host, + device, + device_vector, + alpha_device_vector_beta_zero, + alpha_device_vector_beta_host +}; + +class descriptor; +using descriptor_ptr = descriptor *; +class matrix_layout_t; +using matrix_layout_ptr = matrix_layout_t *; +class matmul_desc_t; +using matmul_desc_ptr = matmul_desc_t *; +class transform_desc_t; +using transform_desc_ptr = transform_desc_t *; + +class descriptor { +public: + descriptor() {} + void init(sycl::queue *q_ptr) { + _engine = ::dnnl::sycl_interop::make_engine(q_ptr->get_device(), + q_ptr->get_context()); + _engine_stream = ::dnnl::sycl_interop::make_stream(_engine, *q_ptr); + } + ::dnnl::engine get_engine() const noexcept { return _engine; } + ::dnnl::stream get_engine_stream() const noexcept { return _engine_stream; }; + +private: + ::dnnl::engine _engine; + ::dnnl::stream _engine_stream; +}; + +class matrix_layout_t { +public: + enum class attribute { type, order, rows, cols, ld }; + + matrix_layout_t(library_data_t type, std::uint64_t rows, std::uint64_t cols, + std::int64_t ld) + : _type(type), _rows(rows), _cols(cols), _ld(ld) {} + + void set_attribute(attribute attr, const void *mem) { + get_set_attr(attr, const_cast(mem)); + } + void get_attribute(attribute attr, void *mem) { + get_set_attr(attr, mem); + } + +private: + template void get_set_attr(attribute attr, void *mem) { +#define CASE(tag) \ + case attribute::tag: \ + if constexpr (is_set) { \ + _##tag = *static_cast(mem); \ + } else { \ + *static_cast(mem) = _##tag; \ + } \ + break; + switch (attr) { + CASE(type) + CASE(order) + CASE(rows) + CASE(cols) + CASE(ld) + } +#undef CASE + } + + library_data_t _type; + order_t _order = order_t::col; + std::uint64_t _rows; + std::uint64_t _cols; + std::int64_t _ld; + + friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc, + const void *alpha, const void *a, + matrix_layout_ptr a_desc, const void *b, + matrix_layout_ptr b_desc, const void *beta, + const void *c, matrix_layout_ptr c_desc, void *d, + matrix_layout_ptr d_desc, dpct::queue_ptr q_ptr); + friend sycl::event + matrix_transform(transform_desc_ptr transform_desc, const void *alpha, + const void *a, matrix_layout_ptr a_desc, const void *beta, + const void *b, matrix_layout_ptr b_desc, void *c, + matrix_layout_ptr c_desc, queue_ptr q_ptr); +}; + +class matmul_desc_t { +public: + enum class attribute { + compute_type, + scale_type, + pointer_mode, + trans_a, + trans_b, + trans_c, + epilogue + }; + + matmul_desc_t(compute_type compute_type, library_data_t scale_type) + : _compute_type(compute_type), _scale_type(scale_type) {} + + void set_attribute(attribute attr, const void *mem) { + get_set_attr(attr, const_cast(mem)); + } + void get_attribute(attribute attr, void *mem) { + get_set_attr(attr, mem); + } + +private: + template void get_set_attr(attribute attr, void *mem) { +#define CASE(tag) \ + case attribute::tag: \ + if constexpr (is_set) { \ + _##tag = *static_cast(mem); \ + } else { \ + *static_cast(mem) = _##tag; \ + } \ + break; + switch (attr) { + CASE(compute_type) + CASE(scale_type) + CASE(pointer_mode) 
+ CASE(trans_a) + CASE(trans_b) + CASE(trans_c) + CASE(epilogue) + } +#undef CASE + } + + compute_type _compute_type; + library_data_t _scale_type; + pointer_mode_t _pointer_mode = pointer_mode_t::host; + oneapi::mkl::transpose _trans_a = oneapi::mkl::transpose::nontrans; + oneapi::mkl::transpose _trans_b = oneapi::mkl::transpose::nontrans; + oneapi::mkl::transpose _trans_c = oneapi::mkl::transpose::nontrans; + int _epilogue = 1; + + friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc, + const void *alpha, const void *a, + matrix_layout_ptr a_desc, const void *b, + matrix_layout_ptr b_desc, const void *beta, + const void *c, matrix_layout_ptr c_desc, void *d, + matrix_layout_ptr d_desc, dpct::queue_ptr q_ptr); +}; + +namespace detail { +/// Sacling each row of matrix D with the corresponding element of vector alpha. +template +sycl::event scale_d_with_vector_alpha_impl(queue_ptr q_ptr, int rows, int cols, + T *d, const Talpha *alpha, + std::vector deps) { + return q_ptr->submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); +#ifdef DPCT_USM_LEVEL_NONE + access_wrapper d_acc(d, cgh); + access_wrapper alpha_acc(alpha, cgh); +#endif + cgh.parallel_for< + dpct_kernel_name>( + sycl::range<2>(rows, cols), [=](sycl::id<2> index) { +#ifdef DPCT_USM_LEVEL_NONE + auto d_data = d_acc.get_raw_pointer(); + auto alpha_data = alpha_acc.get_raw_pointer(); +#else + auto d_data = d; + auto alpha_data = alpha; +#endif + size_t row_idx = index.get(0); + size_t col_idx = index.get(1); + size_t idx = rows * col_idx + row_idx; + d_data[idx] = d_data[idx] * alpha_data[row_idx]; + }); + }); +} + +// d is col major without padding +inline sycl::event scale_d_with_vector_alpha(queue_ptr q_ptr, int rows, + int cols, void *d, + library_data_t d_type, + const void *alpha, + library_data_t alpha_type, + std::vector deps) { + std::uint64_t key = dpct::detail::get_type_combination_id(d_type, alpha_type); + sycl::event e; + switch (key) { + case dpct::detail::get_type_combination_id(library_data_t::real_int8, + library_data_t::real_float): { + e = scale_d_with_vector_alpha_impl( + q_ptr, rows, cols, (std::int8_t *)d, (const float *)alpha, deps); + break; + } + case dpct::detail::get_type_combination_id(library_data_t::real_int32, + library_data_t::real_float): { + e = scale_d_with_vector_alpha_impl(q_ptr, rows, cols, (int *)d, + (const float *)alpha, deps); + break; + } + default: + throw std::runtime_error("dpct::blas_gemm::experimental::detail::scale_d_" + "with_vector_alpha() does not support the data " + "type combination currently."); + } + return e; +} + +/// Get a linear idx map for a 2D point (row_idx, col_idx) between src_order and +/// dst_order. 
+inline std::tuple +get_linear_idx_map(size_t rows, size_t cols, size_t src_ld, order_t src_order, + size_t dst_ld, order_t dst_order, size_t row_idx, + size_t col_idx) { +#define COMBINE(from, to) \ + static_cast(from) << 8 | static_cast(to) + + size_t from_linear_idx, to_linear_idx; + switch (COMBINE(src_order, dst_order)) { + case COMBINE(order_t::col, order_t::row): { + from_linear_idx = src_ld * col_idx + row_idx; + to_linear_idx = dst_ld * row_idx + col_idx; + break; + } + case COMBINE(order_t::row, order_t::col): { + from_linear_idx = src_ld * row_idx + col_idx; + to_linear_idx = dst_ld * col_idx + row_idx; + break; + } + case COMBINE(order_t::col, order_t::col32): { + from_linear_idx = src_ld * col_idx + row_idx; + to_linear_idx = dst_ld * (col_idx / 32) + 32 * row_idx + col_idx % 32; + break; + } + case COMBINE(order_t::col32, order_t::col): { + from_linear_idx = src_ld * (col_idx / 32) + 32 * row_idx + col_idx % 32; + to_linear_idx = dst_ld * col_idx + row_idx; + break; + } + case COMBINE(order_t::col, order_t::col4_4r2_8c): { + from_linear_idx = src_ld * col_idx + row_idx; + + size_t from_row_in_row8_col32 = row_idx % 8; + size_t from_col_in_row8_col32 = col_idx % 32; + + size_t to_row_in_row8_col32 = + 4 * (from_row_in_row8_col32 % 2) + from_col_in_row8_col32 / 8; + size_t to_col_in_row8_col32 = 16 * ((from_col_in_row8_col32 / 4) % 2) + + 4 * (from_row_in_row8_col32 / 2) + + from_col_in_row8_col32 % 4; + size_t to_linear_idx_in_row8_col32 = + to_row_in_row8_col32 * 32 + to_col_in_row8_col32; + + to_linear_idx = dst_ld * (col_idx / 32) + (row_idx / 8) * (32 * 8) + + to_linear_idx_in_row8_col32; + break; + } + case COMBINE(order_t::col4_4r2_8c, order_t::col): { + to_linear_idx = dst_ld * col_idx + row_idx; + + size_t to_row_in_row8_col32 = row_idx % 8; + size_t to_col_in_row8_col32 = col_idx % 32; + + size_t from_row_in_row8_col32 = + 4 * (to_row_in_row8_col32 % 2) + to_col_in_row8_col32 / 8; + size_t from_col_in_row8_col32 = 16 * ((to_col_in_row8_col32 / 4) % 2) + + 4 * (to_row_in_row8_col32 / 2) + + to_col_in_row8_col32 % 4; + size_t from_linear_idx_in_row8_col32 = + from_row_in_row8_col32 * 32 + from_col_in_row8_col32; + + from_linear_idx = src_ld * (col_idx / 32) + (row_idx / 8) * (32 * 8) + + from_linear_idx_in_row8_col32; + break; + } + case COMBINE(order_t::col, order_t::col32_2r_4r4): { + from_linear_idx = src_ld * col_idx + row_idx; + + size_t from_row_in_row32_col32 = row_idx % 32; + size_t from_col_in_row32_col32 = col_idx % 32; + + size_t to_row_in_row32_col32 = 8 * ((from_row_in_row32_col32 % 8) / 2) + + (from_row_in_row32_col32 / 8) * 2 + + from_row_in_row32_col32 % 2; + size_t to_col_in_row32_col32 = from_col_in_row32_col32; + size_t to_linear_idx_in_row32_col32 = + to_row_in_row32_col32 * 32 + to_col_in_row32_col32; + + to_linear_idx = dst_ld * (col_idx / 32) + (row_idx / 32) * (32 * 32) + + to_linear_idx_in_row32_col32; + break; + } + case COMBINE(order_t::col32_2r_4r4, order_t::col): { + to_linear_idx = dst_ld * col_idx + row_idx; + + size_t to_row_in_row32_col32 = row_idx % 32; + size_t to_col_in_row32_col32 = col_idx % 32; + + size_t from_row_in_row32_col32 = 8 * ((to_row_in_row32_col32 % 8) / 2) + + (to_row_in_row32_col32 / 8) * 2 + + to_row_in_row32_col32 % 2; + size_t from_col_in_row32_col32 = to_col_in_row32_col32; + size_t from_linear_idx_in_row32_col32 = + from_row_in_row32_col32 * 32 + from_col_in_row32_col32; + + from_linear_idx = src_ld * (col_idx / 32) + (row_idx / 32) * (32 * 32) + + from_linear_idx_in_row32_col32; + break; + } + } +#undef COMBINE + 
return std::make_tuple(from_linear_idx, to_linear_idx); +} + +template +sycl::event matrix_transform(queue_ptr q_ptr, size_t rows, size_t cols, + size_t a_ld, order_t a_order, const T *a, + size_t c_ld, order_t c_order, T *c, + std::vector deps) { + if ((a_order != order_t::col && c_order != order_t::col) || + (a_order == order_t::col && c_order == order_t::col)) { + throw std::runtime_error( + "dpct::blas_gemm::experimental::detail::matrix_transform() does not " + "support the order combination currently."); + } + + return q_ptr->submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); +#ifdef DPCT_USM_LEVEL_NONE + access_wrapper a_acc(a, cgh); + access_wrapper c_acc(c, cgh); +#endif + cgh.parallel_for>( + sycl::range<2>(a_ld, cols), [=](sycl::id<2> index) { +#ifdef DPCT_USM_LEVEL_NONE + auto a_data = a_acc.get_raw_pointer(); + auto c_data = c_acc.get_raw_pointer(); +#else + auto a_data = a; + auto c_data = c; +#endif + size_t row_idx = index.get(0); + size_t col_idx = index.get(1); + if (row_idx < rows) { + size_t from_linear_idx, to_linear_idx; + std::tie(from_linear_idx, to_linear_idx) = get_linear_idx_map( + rows, cols, a_ld, a_order, c_ld, c_order, row_idx, col_idx); + c_data[to_linear_idx] = a_data[from_linear_idx]; + } + }); + }); +} + +// Convert an integer to an float. +// The integer may on the host or the device, the float is on the device. +#ifdef DPCT_USM_LEVEL_NONE +inline void int2float(queue_ptr q_ptr, void *int_ptr, bool is_host_ptr, + sycl::buffer float_buffer) { + if (is_host_ptr) { + int alpha_host = *reinterpret_cast(int_ptr); + q_ptr->submit([&](sycl::handler &cgh) { + sycl::accessor float_acc(float_buffer, cgh, sycl::write_only, + sycl::no_init); + cgh.single_task>( + [=]() { float_acc[0] = alpha_host; }); + }); + } else { + q_ptr->submit([&](sycl::handler &cgh) { + access_wrapper int_acc(int_ptr, cgh); + sycl::accessor float_acc(float_buffer, cgh, sycl::write_only, + sycl::no_init); + cgh.single_task>([=]() { + auto int_data = int_acc.get_raw_pointer(); + float_acc[0] = int_data[0]; + }); + }); + } +} +#else +inline void int2float(queue_ptr q_ptr, void *int_ptr, bool is_host_ptr, + void *float_ptr) { + if (is_host_ptr) { + int alpha_host = *reinterpret_cast(int_ptr); + q_ptr->submit([&](sycl::handler &cgh) { + cgh.single_task>([=]() { + auto float_data = (float *)float_ptr; + float_data[0] = alpha_host; + }); + }); + } else { + q_ptr->submit([&](sycl::handler &cgh) { + cgh.single_task>([=]() { + auto int_data = (int *)int_ptr; + auto float_data = (float *)float_ptr; + float_data[0] = int_data[0]; + }); + }); + } +} +#endif +} // namespace detail + +/// This function does operation: D = alpha*(A*B) + beta*(C). +/// Currently supports type combinations: +/// scale_type==int32 && a_type==int8 && b_type==int8 && c_type==int32; +/// scale_type==float && a_type==int8 && b_type==int8 && c_type==int8; +/// scale_type==float && a_type==int8 && b_type==int8 && c_type==int32. +/// Currently it only supports beta==0. +/// NOTE: Non-col-major matrix will be converted to col-major matrix before. +/// TODO: Impl row-major matmul without layout conversion. +/// TODO: Impl epilogue for the matmul. +/// multiplication and converted back after multiplication. +/// \param [in] handle A handle containing context info. +/// \param [in] compute_desc Describe the computation. +/// \param [in] alpha Scaling factor alpha. +/// \param [in] a Input matrix A. +/// \param [in] a_desc Describe the matrix A. +/// \param [in] b Input matrix B. +/// \param [in] b_desc Describe the matrix B. 
+/// \param [in] beta Scaling factor beta. +/// \param [in] c Input matrix C. +/// \param [in] c_desc Describe the matrix C. +/// \param [out] d Output matrix D. +/// \param [in] d_desc Describe the matrix D. +/// \param [in] q_ptr The queue where the routine should be executed. +inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc, + const void *alpha, const void *a, + matrix_layout_ptr a_desc, const void *b, + matrix_layout_ptr b_desc, const void *beta, + const void *c, matrix_layout_ptr c_desc, void *d, + matrix_layout_ptr d_desc, dpct::queue_ptr q_ptr) { + if (!q_ptr) + q_ptr = &get_default_queue(); + handle->init(q_ptr); + bool vector_alpha = false; + if (compute_desc->_pointer_mode == pointer_mode_t::device_vector || + compute_desc->_pointer_mode == + pointer_mode_t::alpha_device_vector_beta_zero || + compute_desc->_pointer_mode == + pointer_mode_t::alpha_device_vector_beta_host) { + vector_alpha = true; + } + + if (beta != nullptr) { + size_t beta_size = + dpct::detail::library_data_size[static_cast( + compute_desc->_scale_type)] / + 8; + void *beta_host = std::malloc(beta_size); + void *beta_zero = std::malloc(beta_size); + std::memset(beta_zero, 0, beta_size); + q_ptr->memcpy(beta_host, beta, beta_size).wait(); + if (std::memcmp(beta_host, beta_zero, beta_size)) + throw std::runtime_error("dpct::blas_gemm::experimental::matmul() does " + "not support non-zero beta currently."); + } + + if (compute_desc->_epilogue != 1) { + throw std::runtime_error("dpct::blas_gemm::experimental::matmul() does " + "not support epilogue currently."); + } + + if (compute_desc->_trans_a != oneapi::mkl::transpose::nontrans) { + throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only " + "supports non-transposed matrix A currently."); + } + if (compute_desc->_trans_b != oneapi::mkl::transpose::trans) { + throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only " + "supports transposed matrix B currently."); + } + + if (!(compute_desc->_scale_type == library_data_t::real_int32 && + a_desc->_type == library_data_t::real_int8 && + b_desc->_type == library_data_t::real_int8 && + c_desc->_type == library_data_t::real_int32) && + !(compute_desc->_scale_type == library_data_t::real_float && + a_desc->_type == library_data_t::real_int8 && + b_desc->_type == library_data_t::real_int8 && + c_desc->_type == library_data_t::real_int8) && + !(compute_desc->_scale_type == library_data_t::real_float && + a_desc->_type == library_data_t::real_int8 && + b_desc->_type == library_data_t::real_int8 && + c_desc->_type == library_data_t::real_int32)) { + throw std::runtime_error( + "dpct::blas_gemm::experimental::matmul() only supports data type " + "combinataions:\n scale_type==int32 && a_type==int8 && b_type==int8 " + "&& c_type==int32,\n scale_type==float && a_type==int8 && " + "b_type==int8 && c_type==int8 or\n scale_type==float && a_type==int8 " + "&& b_type==int8 && c_type==int32."); + } + + // For non-col_major matrix, convert it to col_major. 
+ const void *new_a = a; + const void *new_b = b; + void *new_d = d; + size_t new_lda = a_desc->_ld, new_ldb = b_desc->_ld, new_ldd = d_desc->_ld; + std::vector transform_events; + if (a_desc->_order != order_t::col) { + new_lda = a_desc->_rows; + if (a_desc->_type == library_data_t::real_int8) { + new_a = + dpct_malloc(sizeof(std::int8_t) * a_desc->_cols * new_lda, *q_ptr); + sycl::event e = detail::matrix_transform( + q_ptr, a_desc->_rows, a_desc->_cols, a_desc->_ld, a_desc->_order, + (const std::int8_t *)a, new_lda, order_t::col, (std::int8_t *)new_a, + {}); + transform_events.push_back(e); + } else { + new_a = dpct_malloc(sizeof(int) * a_desc->_cols * new_lda, *q_ptr); + sycl::event e = detail::matrix_transform( + q_ptr, a_desc->_rows, a_desc->_cols, a_desc->_ld, a_desc->_order, + (const int *)a, new_lda, order_t::col, (int *)new_a, {}); + transform_events.push_back(e); + } + } + if (b_desc->_order != order_t::col) { + new_ldb = b_desc->_rows; + if (b_desc->_type == library_data_t::real_int8) { + new_b = + dpct_malloc(sizeof(std::int8_t) * b_desc->_cols * new_ldb, *q_ptr); + sycl::event e = detail::matrix_transform( + q_ptr, b_desc->_rows, b_desc->_cols, b_desc->_ld, b_desc->_order, + (const std::int8_t *)b, new_ldb, order_t::col, (std::int8_t *)new_b, + {}); + transform_events.push_back(e); + } else { + new_b = dpct_malloc(sizeof(int) * b_desc->_cols * new_ldb, *q_ptr); + sycl::event e = detail::matrix_transform( + q_ptr, b_desc->_rows, b_desc->_cols, b_desc->_ld, b_desc->_order, + (const int *)b, new_ldb, order_t::col, (int *)new_b, {}); + transform_events.push_back(e); + } + } + if (d_desc->_order != order_t::col) { + new_ldd = d_desc->_rows; + if (d_desc->_type == library_data_t::real_int8) { + new_d = + dpct_malloc(sizeof(std::int8_t) * d_desc->_cols * new_ldd, *q_ptr); + } else { + new_d = dpct_malloc(sizeof(int) * d_desc->_cols * new_ldd, *q_ptr); + } + } + + // start to call oneDNN matmul primitive + // a,d are col_major, b is row_major + const size_t m = a_desc->_rows; + const size_t n = d_desc->_cols; + const size_t k = b_desc->_cols; + const ::dnnl::memory::dim M = m; + const ::dnnl::memory::dim N = n; + const ::dnnl::memory::dim K = k; + const library_data_t a_type = a_desc->_type; + const library_data_t b_type = b_desc->_type; + const library_data_t d_type = d_desc->_type; + const library_data_t scale_type = compute_desc->_scale_type; + + ::dnnl::memory::dims src_dims = {M, K}; + ::dnnl::memory::dims weights_dims = {K, N}; + ::dnnl::memory::dims dst_dims = {M, N}; + + const ::dnnl::memory::dims src_strides = + ::dnnl::memory::dims{1, static_cast(new_lda)}; + const ::dnnl::memory::dims weights_strides = + ::dnnl::memory::dims{static_cast(new_ldb), 1}; + const ::dnnl::memory::dims dst_strides = + ::dnnl::memory::dims{1, static_cast(new_ldd)}; + + auto src_md = ::dnnl::memory::desc( + src_dims, dpct::dnnl::memory_desc_ext::to_dnnl_data_type(a_type), + src_strides); + auto weights_md = ::dnnl::memory::desc( + weights_dims, dpct::dnnl::memory_desc_ext::to_dnnl_data_type(b_type), + weights_strides); + auto dst_md = ::dnnl::memory::desc( + dst_dims, dpct::dnnl::memory_desc_ext::to_dnnl_data_type(d_type), + dst_strides); + + auto *src_mem = + new ::dnnl::memory(src_md, handle->get_engine(), DNNL_MEMORY_NONE); + auto *weights_mem = + new ::dnnl::memory(weights_md, handle->get_engine(), DNNL_MEMORY_NONE); + auto *dst_mem = + new ::dnnl::memory(dst_md, handle->get_engine(), DNNL_MEMORY_NONE); + +#ifdef DPCT_USM_LEVEL_NONE +#define SET_BUFFER(DST, TYPE, SRC) \ + { \ + switch (TYPE) { 
\ + case library_data_t::real_int8: { \ + auto buf = get_buffer(SRC); \ + ::dnnl::sycl_interop::set_buffer(*DST, buf); \ + break; \ + } \ + case library_data_t::real_int32: { \ + auto buf = get_buffer(SRC); \ + ::dnnl::sycl_interop::set_buffer(*DST, buf); \ + break; \ + } \ + default: \ + throw std::runtime_error("dpct::blas_gemm::experimental::matmul() does " \ + "not support type (dpct::library_data_t) " + \ + std::to_string((std::uint8_t)TYPE) + \ + " currently."); \ + } \ + } + + SET_BUFFER(src_mem, a_type, new_a); + SET_BUFFER(weights_mem, b_type, new_b); + SET_BUFFER(dst_mem, d_type, new_d); +#undef SET_BUFFER +#else + src_mem->set_data_handle(const_cast(new_a)); + weights_mem->set_data_handle(const_cast(new_b)); + dst_mem->set_data_handle(new_d); +#endif + + std::unordered_map matmul_args; + matmul_args.insert({DNNL_ARG_SRC, *src_mem}); + matmul_args.insert({DNNL_ARG_WEIGHTS, *weights_mem}); + matmul_args.insert({DNNL_ARG_DST, *dst_mem}); + ::dnnl::primitive_attr matmul_attr; + ::dnnl::memory *scales_alpha = nullptr; + if (!vector_alpha) { + matmul_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + scales_alpha = new ::dnnl::memory( + {{1}, ::dnnl::memory::data_type::f32, {1}}, handle->get_engine()); + if (scale_type != library_data_t::real_float) { +#ifdef DPCT_USM_LEVEL_NONE + *scales_alpha = ::dnnl::sycl_interop::make_memory( + {{1}, ::dnnl::memory::data_type::f32, {1}}, handle->get_engine(), + ::dnnl::sycl_interop::memory_kind::buffer); +#endif + detail::int2float( + q_ptr, const_cast(alpha), + compute_desc->_pointer_mode == pointer_mode_t::host, +#ifdef DPCT_USM_LEVEL_NONE + ::dnnl::sycl_interop::get_buffer(*scales_alpha) +#else + scales_alpha->get_data_handle() +#endif + ); + } else { + dpct::dpct_memcpy(scales_alpha->get_data_handle(), alpha, sizeof(float), + automatic, *q_ptr); + } + matmul_args.insert( + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *scales_alpha}); + } + + auto matmul_pd = ::dnnl::matmul::primitive_desc( + handle->get_engine(), src_md, weights_md, dst_md, matmul_attr); + auto matmul_prim = ::dnnl::matmul(matmul_pd); + sycl::event matmul_prim_event = ::dnnl::sycl_interop::execute( + matmul_prim, handle->get_engine_stream(), matmul_args, transform_events); + + sycl::event scale_d_event; + if (vector_alpha) + scale_d_event = detail::scale_d_with_vector_alpha( + q_ptr, m, n, new_d, d_type, alpha, scale_type, {matmul_prim_event}); + // end of calling oneDNN + + sycl::event transform_d_event; + if (d_desc->_order != order_t::col) { + if (d_desc->_type == library_data_t::real_int8) { + transform_d_event = detail::matrix_transform( + q_ptr, d_desc->_rows, d_desc->_cols, new_ldd, order_t::col, + (const std::int8_t *)new_d, d_desc->_ld, d_desc->_order, + (std::int8_t *)d, {scale_d_event, matmul_prim_event}); + } else { + transform_d_event = detail::matrix_transform( + q_ptr, d_desc->_rows, d_desc->_cols, new_ldd, order_t::col, + (const int *)new_d, d_desc->_ld, d_desc->_order, (int *)d, + {scale_d_event, matmul_prim_event}); + } + } + + sycl::event free_event = q_ptr->submit([&](sycl::handler &cgh) { + cgh.depends_on({transform_d_event, scale_d_event, matmul_prim_event}); + cgh.host_task([=] { + delete src_mem; + delete weights_mem; + delete dst_mem; + if (!vector_alpha) + delete scales_alpha; + dpct::detail::dpct_free((void *)new_a, *q_ptr); + dpct::detail::dpct_free((void *)new_b, *q_ptr); + dpct::detail::dpct_free((void *)new_d, *q_ptr); + }); + }); + return free_event; +} + +class transform_desc_t { +public: + enum class attribute { scale_type, pointer_mode, trans_a, 
trans_b };
+
+  transform_desc_t(library_data_t scale_type) : _scale_type(scale_type) {}
+  void set_attribute(attribute attr, const void *mem) {
+    get_set_attr<true>(attr, const_cast<void *>(mem));
+  }
+  void get_attribute(attribute attr, void *mem) {
+    get_set_attr<false>(attr, mem);
+  }
+
+private:
+  template <bool is_set> void get_set_attr(attribute attr, void *mem) {
+#define CASE(tag)                                                              \
+  case attribute::tag:                                                         \
+    if constexpr (is_set) {                                                    \
+      _##tag = *static_cast<decltype(_##tag) *>(mem);                          \
+    } else {                                                                   \
+      *static_cast<decltype(_##tag) *>(mem) = _##tag;                          \
+    }                                                                          \
+    break;
+    switch (attr) {
+      CASE(scale_type)
+      CASE(pointer_mode)
+      CASE(trans_a)
+      CASE(trans_b)
+    }
+#undef CASE
+  }
+
+  library_data_t _scale_type;
+  pointer_mode_t _pointer_mode = pointer_mode_t::host;
+  oneapi::mkl::transpose _trans_a = oneapi::mkl::transpose::nontrans;
+  oneapi::mkl::transpose _trans_b = oneapi::mkl::transpose::nontrans;
+
+  friend sycl::event
+  matrix_transform(transform_desc_ptr transform_desc, const void *alpha,
+                   const void *a, matrix_layout_ptr a_desc, const void *beta,
+                   const void *b, matrix_layout_ptr b_desc, void *c,
+                   matrix_layout_ptr c_desc, queue_ptr q_ptr);
+};
+
+/// This function does operation:
+/// C = alpha*transformation(A) + beta*transformation(B).
+/// The "transformation" includes matrix transpose and layout conversion.
+/// Currently it only supports alpha==1 && beta==0.
+/// \param [in] transform_desc Describe the transformation.
+/// \param [in] alpha Scaling factor alpha.
+/// \param [in] a Input matrix A.
+/// \param [in] a_desc Describe the matrix A.
+/// \param [in] beta Scaling factor beta.
+/// \param [in] b Input matrix B.
+/// \param [in] b_desc Describe the matrix B.
+/// \param [out] c Output matrix C.
+/// \param [in] c_desc Describe the matrix C.
+/// \param [in] q_ptr The queue where the routine should be executed.
+inline sycl::event matrix_transform(transform_desc_ptr transform_desc, + const void *alpha, const void *a, + matrix_layout_ptr a_desc, const void *beta, + const void *b, matrix_layout_ptr b_desc, + void *c, matrix_layout_ptr c_desc, + queue_ptr q_ptr) { + if (!q_ptr) + q_ptr = &get_default_queue(); + + if (transform_desc->_pointer_mode != pointer_mode_t::host) { + throw std::runtime_error( + "dpct::blas_gemm::experimental::matrix_transform() " + "only supports pointer_mode_t::host as pointer_mode currently."); + } + if (transform_desc->_scale_type != library_data_t::real_float) { + throw std::runtime_error( + "dpct::blas_gemm::experimental::matrix_transform() " + "only supports library_data_t::real_float as scale_type currently."); + } + + if (alpha != nullptr) { + if (1.0f != *reinterpret_cast(alpha)) + throw std::runtime_error( + "dpct::blas_gemm::experimental::matrix_transform() does not " + "support non-one alpha currently."); + } + + if (beta != nullptr) { + if (0.0f != *reinterpret_cast(beta)) + throw std::runtime_error( + "dpct::blas_gemm::experimental::matrix_transform() does not " + "support non-zero beta currently."); + } + + if (b != nullptr) { + throw std::runtime_error( + "dpct::blas_gemm::experimental::matrix_transform() does not " + "support matrix B currently."); + } + + if ((a_desc->_type != library_data_t::real_int8 || + c_desc->_type != library_data_t::real_int8) && + (a_desc->_type != library_data_t::real_int32 || + c_desc->_type != library_data_t::real_int32)) { + throw std::runtime_error( + "dpct::blas_gemm::experimental::matrix_transform() only supports " + "combinations of data types: a_type==real_int8&&c_type==real_int8, " + "a_type==real_int32&&c_type==real_int32."); + } + + if (a_desc->_type == library_data_t::real_int8) { + return detail::matrix_transform( + q_ptr, a_desc->_rows, a_desc->_cols, a_desc->_ld, a_desc->_order, + (const std::int8_t *)a, c_desc->_ld, c_desc->_order, (std::int8_t *)c, + {}); + } else { + return detail::matrix_transform( + q_ptr, a_desc->_rows, a_desc->_cols, a_desc->_ld, a_desc->_order, + (const int *)a, c_desc->_ld, c_desc->_order, (int *)c, {}); + } +} +} // namespace experimental +} // namespace blas_gemm +} // namespace dpct +#endif // __DPCT_BLAS_GEMM_UTILS_HPP__ diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 4dfc9600ae15..5420fc3a4c28 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -291,18 +291,6 @@ enum class math_mode : int { mm_default, mm_tf32, }; -enum class compute_type : int { - f16, - f16_standard, - f32, - f32_standard, - f32_fast_bf16, - f32_fast_tf32, - f64, - f64_standard, - i32, - i32_standard, -}; class descriptor { public: @@ -2467,7 +2455,6 @@ trmm(sycl::queue &q, oneapi::mkl::side left_right, blas::trmm(&desc, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, ldb, c, ldc); } - } // namespace dpct #undef DPCT_COMPUTE_MODE_ARG #undef DPCT_COMPUTE_MODE_PARAM diff --git a/clang/runtime/dpct-rt/include/dpct/lib_common_utils.hpp b/clang/runtime/dpct-rt/include/dpct/lib_common_utils.hpp index 03fd0c85df07..5dca30833c71 100644 --- a/clang/runtime/dpct-rt/include/dpct/lib_common_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/lib_common_utils.hpp @@ -101,6 +101,19 @@ enum class library_data_t : unsigned char { library_data_t_size }; +enum class compute_type : int { + f16, + f16_standard, + f32, + f32_standard, + f32_fast_bf16, + f32_fast_tf32, + f64, + 
f64_standard, + i32, + i32_standard, +}; + namespace detail { template inline constexpr std::uint64_t get_type_combination_id(ArgT Val) { diff --git a/clang/test/dpct/check_header_files.cpp b/clang/test/dpct/check_header_files.cpp index f0f154dbcfd0..7e9fda92b7d7 100644 --- a/clang/test/dpct/check_header_files.cpp +++ b/clang/test/dpct/check_header_files.cpp @@ -86,6 +86,11 @@ // RUN: echo "end" >> %T/check_header_files/diff_res.txt // RUN: FileCheck %s --match-full-lines --input-file %T/check_header_files/diff_res.txt +// RUN: echo "begin" > %T/check_header_files/diff_res.txt +// RUN: diff %T/check_header_files/out/include/dpct/blas_gemm_utils.hpp %S/../../runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp >> %T/check_header_files/diff_res.txt +// RUN: echo "end" >> %T/check_header_files/diff_res.txt +// RUN: FileCheck %s --match-full-lines --input-file %T/check_header_files/diff_res.txt + // RUN: echo "begin" > %T/check_header_files/diff_res.txt // RUN: diff %T/check_header_files/out/include/dpct/dpl_extras/algorithm.h %S/../../runtime/dpct-rt/include/dpct/dpl_extras/algorithm.h >> %T/check_header_files/diff_res.txt // RUN: echo "end" >> %T/check_header_files/diff_res.txt diff --git a/clang/test/dpct/cublas-usm-11.cu b/clang/test/dpct/cublas-usm-11.cu index c04b7eb2532e..11522b38c4f3 100644 --- a/clang/test/dpct/cublas-usm-11.cu +++ b/clang/test/dpct/cublas-usm-11.cu @@ -19,9 +19,9 @@ void foo1() { //CHECK-NEXT:dpct::scal(handle->get_queue(), 4, alpha, dpct::library_data_t::real_float, x, dpct::library_data_t::real_float, 1); //CHECK-NEXT:dpct::axpy(handle->get_queue(), 4, alpha, dpct::library_data_t::real_float, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1); //CHECK-NEXT:dpct::rot(handle->get_queue(), 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, cos, sin, dpct::library_data_t::real_float); - //CHECK-NEXT:dpct::blas::gemm(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a, dpct::library_data_t::real_half, 4, b, dpct::library_data_t::real_half, 4, beta, c, dpct::library_data_t::real_half, 4, dpct::blas::compute_type::f16); - //CHECK-NEXT:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a_array, dpct::library_data_t::real_half, 4, b_array, dpct::library_data_t::real_half, 4, beta, c_array, dpct::library_data_t::real_half, 4, 2, dpct::blas::compute_type::f16); - //CHECK-NEXT:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a, dpct::library_data_t::real_half, 4, 16, b, dpct::library_data_t::real_half, 4, 16, beta, c, dpct::library_data_t::real_half, 4, 16, 2, dpct::blas::compute_type::f16); + //CHECK-NEXT:dpct::blas::gemm(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a, dpct::library_data_t::real_half, 4, b, dpct::library_data_t::real_half, 4, beta, c, dpct::library_data_t::real_half, 4, dpct::compute_type::f16); + //CHECK-NEXT:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a_array, dpct::library_data_t::real_half, 4, b_array, dpct::library_data_t::real_half, 4, beta, c_array, dpct::library_data_t::real_half, 4, 2, dpct::compute_type::f16); + //CHECK-NEXT:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a, dpct::library_data_t::real_half, 4, 16, b, 
+  //CHECK-NEXT:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a, dpct::library_data_t::real_half, 4, 16, b, dpct::library_data_t::real_half, 4, 16, beta, c, dpct::library_data_t::real_half, 4, 16, 2, dpct::compute_type::f16);
   cublasNrm2Ex(handle, 4, x, CUDA_R_32F, 1, res, CUDA_R_32F, CUDA_R_32F);
   cublasDotEx(handle, 4, x, CUDA_R_32F, 1, y, CUDA_R_32F, 1, res, CUDA_R_32F, CUDA_R_32F);
   cublasDotcEx(handle, 4, x, CUDA_R_32F, 1, y, CUDA_R_32F, 1, res, CUDA_R_32F, CUDA_R_32F);
@@ -40,7 +40,7 @@ void foo2() {
   void **b_array;
   void **c_array;
 
-  //CHECK:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, const_cast<const void **>(a_array), dpct::library_data_t::real_half, 4, const_cast<const void **>(b_array), dpct::library_data_t::real_half, 4, beta, c_array, dpct::library_data_t::real_half, 4, 2, dpct::blas::compute_type::f16);
+  //CHECK:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, const_cast<const void **>(a_array), dpct::library_data_t::real_half, 4, const_cast<const void **>(b_array), dpct::library_data_t::real_half, 4, beta, c_array, dpct::library_data_t::real_half, 4, 2, dpct::compute_type::f16);
   cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 4, 4, 4, alpha, a_array, CUDA_R_16F, 4, b_array, CUDA_R_16F, 4, beta, c_array, CUDA_R_16F, 4, 2, CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT);
 }
diff --git a/clang/test/dpct/cublaslt.cu b/clang/test/dpct/cublaslt.cu
new file mode 100644
index 000000000000..f699ead6c7e3
--- /dev/null
+++ b/clang/test/dpct/cublaslt.cu
@@ -0,0 +1,275 @@
+// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2, cuda-11.0, cuda-11.1, cuda-11.2, cuda-11.3, cuda-11.4, cuda-11.5
+// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2, v11.0, v11.1, v11.2, v11.3, v11.4, v11.5
+// RUN: dpct --format-range=none --out-root %T/cublaslt %s --cuda-include-path="%cuda-path/include"
+// RUN: FileCheck --input-file %T/cublaslt/cublaslt.dp.cpp --match-full-lines %s
+// RUN: %if build_lit %{icpx -c -fsycl %T/cublaslt/cublaslt.dp.cpp -o %T/cublaslt/cublaslt.dp.o %}
+
+#include "cublasLt.h"
+
+void foo1 () {
+  // CHECK: dpct::blas_gemm::experimental::descriptor_ptr ltHandle;
+  // CHECK-NEXT: ltHandle = new dpct::blas_gemm::experimental::descriptor();
+  // CHECK-NEXT: delete (ltHandle);
+  cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+  cublasLtDestroy(ltHandle);
+
+  // CHECK: dpct::blas_gemm::experimental::matrix_layout_ptr matLayout;
+  // CHECK-NEXT: dpct::library_data_t type;
+  // CHECK-NEXT: uint64_t rows;
+  // CHECK-NEXT: uint64_t cols;
+  // CHECK-NEXT: int64_t ld;
+  // CHECK-NEXT: matLayout = new dpct::blas_gemm::experimental::matrix_layout_t(type, rows, cols, ld);
+  cublasLtMatrixLayout_t matLayout;
+  cudaDataType type;
+  uint64_t rows;
+  uint64_t cols;
+  int64_t ld;
+  cublasLtMatrixLayoutCreate(&matLayout, type, rows, cols, ld);
+
+  // CHECK: dpct::blas_gemm::experimental::matrix_layout_t::attribute attr1;
+  // CHECK-NEXT: void *buf1;
+  // CHECK-NEXT: size_t sizeInBytes1;
+  // CHECK-NEXT: size_t *sizeWritten1;
+  // CHECK-NEXT: matLayout->get_attribute(attr1, buf1);
+  // CHECK-NEXT: matLayout->set_attribute(attr1, buf1);
+  // CHECK-NEXT: delete (matLayout);
+  cublasLtMatrixLayoutAttribute_t attr1;
+  void *buf1;
+  size_t sizeInBytes1;
+  size_t *sizeWritten1;
+  cublasLtMatrixLayoutGetAttribute(matLayout, attr1, buf1, sizeInBytes1, sizeWritten1);
+  cublasLtMatrixLayoutSetAttribute(matLayout, attr1, buf1, sizeInBytes1);
+  cublasLtMatrixLayoutDestroy(matLayout);
+
+  // CHECK: dpct::blas_gemm::experimental::matmul_desc_ptr matmulDesc;
+  // CHECK-NEXT: dpct::compute_type computeType;
+  // CHECK-NEXT: dpct::library_data_t scaleType;
+  // CHECK-NEXT: matmulDesc = new dpct::blas_gemm::experimental::matmul_desc_t(computeType, scaleType);
+  cublasLtMatmulDesc_t matmulDesc;
+  cublasComputeType_t computeType;
+  cudaDataType_t scaleType;
+  cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType);
+
+  // CHECK: dpct::blas_gemm::experimental::matmul_desc_t::attribute attr2;
+  // CHECK-NEXT: void *buf2;
+  // CHECK-NEXT: size_t sizeInBytes2;
+  // CHECK-NEXT: size_t *sizeWritten2;
+  // CHECK-NEXT: matmulDesc->get_attribute(attr2, buf2);
+  // CHECK-NEXT: matmulDesc->set_attribute(attr2, buf2);
+  // CHECK-NEXT: delete (matmulDesc);
+  cublasLtMatmulDescAttributes_t attr2;
+  void *buf2;
+  size_t sizeInBytes2;
+  size_t *sizeWritten2;
+  cublasLtMatmulDescGetAttribute(matmulDesc, attr2, buf2, sizeInBytes2, sizeWritten2);
+  cublasLtMatmulDescSetAttribute(matmulDesc, attr2, buf2, sizeInBytes2);
+  cublasLtMatmulDescDestroy(matmulDesc);
+
+  // CHECK: int matmulPreference;
+  // CHECK-NEXT: /*
+  // CHECK-NEXT: DPCT1026:{{[0-9]+}}: The call to cublasLtMatmulPreferenceCreate was removed because this functionality is redundant in SYCL.
+  // CHECK-NEXT: */
+  // CHECK-NEXT: void *buf3;
+  // CHECK-NEXT: size_t sizeInBytes3;
+  // CHECK-NEXT: size_t *sizeWritten3;
+  // CHECK-NEXT: /*
+  // CHECK-NEXT: DPCT1026:{{[0-9]+}}: The call to cublasLtMatmulPreferenceGetAttribute was removed because this functionality is redundant in SYCL.
+  // CHECK-NEXT: */
+  // CHECK-NEXT: /*
+  // CHECK-NEXT: DPCT1026:{{[0-9]+}}: The call to cublasLtMatmulPreferenceSetAttribute was removed because this functionality is redundant in SYCL.
+  // CHECK-NEXT: */
+  // CHECK-NEXT: /*
+  // CHECK-NEXT: DPCT1026:{{[0-9]+}}: The call to cublasLtMatmulPreferenceDestroy was removed because this functionality is redundant in SYCL.
+  // CHECK-NEXT: */
+  cublasLtMatmulPreference_t matmulPreference;
+  cublasLtMatmulPreferenceCreate(&matmulPreference);
+  void *buf3;
+  size_t sizeInBytes3;
+  size_t *sizeWritten3;
+  cublasLtMatmulPreferenceGetAttribute(matmulPreference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, buf3, sizeInBytes3, sizeWritten3);
+  cublasLtMatmulPreferenceSetAttribute(matmulPreference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, buf3, sizeInBytes3);
+  cublasLtMatmulPreferenceDestroy(matmulPreference);
+
+  cublasLtMatrixLayout_t Adesc;
+  cublasLtMatrixLayout_t Bdesc;
+  cublasLtMatrixLayout_t Cdesc;
+  cublasLtMatrixLayout_t Ddesc;
+
+  // CHECK: int requestedAlgoCount = 1;
+  // CHECK-NEXT: int heuristicResultsArray;
+  // CHECK-NEXT: int returnAlgoCount;
+  // CHECK-NEXT: returnAlgoCount = 1;
+  int requestedAlgoCount = 1;
+  cublasLtMatmulHeuristicResult_t heuristicResultsArray;
+  int returnAlgoCount;
+  cublasLtMatmulAlgoGetHeuristic(ltHandle, matmulDesc, Adesc, Bdesc, Cdesc, Ddesc, matmulPreference, requestedAlgoCount, &heuristicResultsArray, &returnAlgoCount);
+}
+
+void foo2() {
+  // CHECK: dpct::blas_gemm::experimental::descriptor_ptr lightHandle;
+  // CHECK-NEXT: dpct::blas_gemm::experimental::matmul_desc_ptr computeDesc;
+  // CHECK-NEXT: const void *alpha;
+  // CHECK-NEXT: const void *A;
+  // CHECK-NEXT: dpct::blas_gemm::experimental::matrix_layout_ptr Adesc;
+  // CHECK-NEXT: const void *B;
+  // CHECK-NEXT: dpct::blas_gemm::experimental::matrix_layout_ptr Bdesc;
+  // CHECK-NEXT: const void *beta;
+  // CHECK-NEXT: const void *C;
+  // CHECK-NEXT: dpct::blas_gemm::experimental::matrix_layout_ptr Cdesc;
+  // CHECK-NEXT: void *D;
+  // CHECK-NEXT: dpct::blas_gemm::experimental::matrix_layout_ptr Ddesc;
+  // CHECK-NEXT: const int *algo;
+  // CHECK-NEXT: void *workspace;
+  // CHECK-NEXT: size_t workspaceSizeInBytes;
+  // CHECK-NEXT: dpct::queue_ptr stream;
+  // CHECK-NEXT: dpct::blas_gemm::experimental::matmul(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, stream);
+  cublasLtHandle_t lightHandle;
+  cublasLtMatmulDesc_t computeDesc;
+  const void *alpha;
+  const void *A;
+  cublasLtMatrixLayout_t Adesc;
+  const void *B;
+  cublasLtMatrixLayout_t Bdesc;
+  const void *beta;
+  const void *C;
+  cublasLtMatrixLayout_t Cdesc;
+  void *D;
+  cublasLtMatrixLayout_t Ddesc;
+  const cublasLtMatmulAlgo_t *algo;
+  void *workspace;
+  size_t workspaceSizeInBytes;
+  cudaStream_t stream;
+  cublasLtMatmul(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, algo, workspace, workspaceSizeInBytes, stream);
+}
+
+void foo3() {
+  // CHECK: dpct::blas_gemm::experimental::order_t a;
+  // CHECK-NEXT: a = dpct::blas_gemm::experimental::order_t::col;
+  // CHECK-NEXT: a = dpct::blas_gemm::experimental::order_t::row;
+  // CHECK-NEXT: a = dpct::blas_gemm::experimental::order_t::col32;
+  // CHECK-NEXT: a = dpct::blas_gemm::experimental::order_t::col4_4r2_8c;
+  // CHECK-NEXT: a = dpct::blas_gemm::experimental::order_t::col32_2r_4r4;
+  cublasLtOrder_t a;
+  a = CUBLASLT_ORDER_COL;
+  a = CUBLASLT_ORDER_ROW;
+  a = CUBLASLT_ORDER_COL32;
+  a = CUBLASLT_ORDER_COL4_4R2_8C;
+  a = CUBLASLT_ORDER_COL32_2R_4R4;
+  // CHECK: dpct::blas_gemm::experimental::pointer_mode_t b;
+  // CHECK-NEXT: b = dpct::blas_gemm::experimental::pointer_mode_t::host;
+  // CHECK-NEXT: b = dpct::blas_gemm::experimental::pointer_mode_t::device;
+  // CHECK-NEXT: b = dpct::blas_gemm::experimental::pointer_mode_t::device_vector;
+  // CHECK-NEXT: b = dpct::blas_gemm::experimental::pointer_mode_t::alpha_device_vector_beta_zero;
+  // CHECK-NEXT: b = dpct::blas_gemm::experimental::pointer_mode_t::alpha_device_vector_beta_host;
+  cublasLtPointerMode_t b;
+  b = CUBLASLT_POINTER_MODE_HOST;
+  b = CUBLASLT_POINTER_MODE_DEVICE;
+  b = CUBLASLT_POINTER_MODE_DEVICE_VECTOR;
+  b = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
+  b = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST;
+  // CHECK: dpct::blas_gemm::experimental::matrix_layout_t::attribute c;
+  // CHECK-NEXT: c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::type;
+  // CHECK-NEXT: c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::order;
+  // CHECK-NEXT: c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::rows;
+  // CHECK-NEXT: c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::cols;
+  // CHECK-NEXT: c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::ld;
+  cublasLtMatrixLayoutAttribute_t c;
+  c = CUBLASLT_MATRIX_LAYOUT_TYPE;
+  c = CUBLASLT_MATRIX_LAYOUT_ORDER;
+  c = CUBLASLT_MATRIX_LAYOUT_ROWS;
+  c = CUBLASLT_MATRIX_LAYOUT_COLS;
+  c = CUBLASLT_MATRIX_LAYOUT_LD;
+  // CHECK: dpct::blas_gemm::experimental::matmul_desc_t::attribute d;
+  // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::compute_type;
+  // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::scale_type;
+  // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::pointer_mode;
+  // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::trans_a;
+  // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::trans_b;
+  // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::trans_c;
+  // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue;
+  cublasLtMatmulDescAttributes_t d;
+  d = CUBLASLT_MATMUL_DESC_COMPUTE_TYPE;
+  d = CUBLASLT_MATMUL_DESC_SCALE_TYPE;
+  d = CUBLASLT_MATMUL_DESC_POINTER_MODE;
+  d = CUBLASLT_MATMUL_DESC_TRANSA;
+  d = CUBLASLT_MATMUL_DESC_TRANSB;
+  d = CUBLASLT_MATMUL_DESC_TRANSC;
+  d = CUBLASLT_MATMUL_DESC_EPILOGUE;
+  // CHECK: int e;
+  // CHECK-NEXT: e = 1;
+  // CHECK-NEXT: e = 2;
+  // CHECK-NEXT: e = 130;
+  // CHECK-NEXT: e = 4;
+  // CHECK-NEXT: e = 6;
+  // CHECK-NEXT: e = 134;
+  // CHECK-NEXT: e = 136;
+  // CHECK-NEXT: e = 152;
+  // CHECK-NEXT: e = 32;
+  // CHECK-NEXT: e = 160;
+  // CHECK-NEXT: e = 36;
+  // CHECK-NEXT: e = 164;
+  // CHECK-NEXT: e = 192;
+  // CHECK-NEXT: e = 208;
+  // CHECK-NEXT: e = 256;
+  // CHECK-NEXT: e = 512;
+  cublasLtEpilogue_t e;
+  e = CUBLASLT_EPILOGUE_DEFAULT;
+  e = CUBLASLT_EPILOGUE_RELU;
+  e = CUBLASLT_EPILOGUE_RELU_AUX;
+  e = CUBLASLT_EPILOGUE_BIAS;
+  e = CUBLASLT_EPILOGUE_RELU_BIAS;
+  e = CUBLASLT_EPILOGUE_RELU_AUX_BIAS;
+  e = CUBLASLT_EPILOGUE_DRELU;
+  e = CUBLASLT_EPILOGUE_DRELU_BGRAD;
+  e = CUBLASLT_EPILOGUE_GELU;
+  e = CUBLASLT_EPILOGUE_GELU_AUX;
+  e = CUBLASLT_EPILOGUE_GELU_BIAS;
+  e = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
+  e = CUBLASLT_EPILOGUE_DGELU;
+  e = CUBLASLT_EPILOGUE_DGELU_BGRAD;
+  e = CUBLASLT_EPILOGUE_BGRADA;
+  e = CUBLASLT_EPILOGUE_BGRADB;
+}
+
+void foo4() {
+  // CHECK: dpct::blas_gemm::experimental::transform_desc_ptr transformDesc;
+  // CHECK-NEXT: dpct::library_data_t scaleType;
+  // CHECK-NEXT: transformDesc = new dpct::blas_gemm::experimental::transform_desc_t(scaleType);
+  // CHECK-NEXT: oneapi::mkl::transpose opT = oneapi::mkl::transpose::trans;
+  // CHECK-NEXT: size_t sizeWritten;
+  // CHECK-NEXT: transformDesc->set_attribute(dpct::blas_gemm::experimental::transform_desc_t::attribute::trans_a, &opT);
+  // CHECK-NEXT: transformDesc->get_attribute(dpct::blas_gemm::experimental::transform_desc_t::attribute::trans_a, &opT);
+  // CHECK-NEXT: delete (transformDesc);
+  cublasLtMatrixTransformDesc_t transformDesc;
+  cudaDataType scaleType;
+  cublasLtMatrixTransformDescCreate(&transformDesc, scaleType);
+  cublasOperation_t opT = CUBLAS_OP_T;
+  size_t sizeWritten;
+  cublasLtMatrixTransformDescSetAttribute(transformDesc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opT, sizeof(opT));
+  cublasLtMatrixTransformDescGetAttribute(transformDesc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opT, sizeof(opT), &sizeWritten);
+  cublasLtMatrixTransformDescDestroy(transformDesc);
+
+  // CHECK: dpct::blas_gemm::experimental::descriptor_ptr lightHandle;
+  // CHECK-NEXT: const void *alpha;
+  // CHECK-NEXT: const void *A;
+  // CHECK-NEXT: dpct::blas_gemm::experimental::matrix_layout_ptr Adesc;
+  // CHECK-NEXT: const void *beta;
+  // CHECK-NEXT: const void *B;
+  // CHECK-NEXT: dpct::blas_gemm::experimental::matrix_layout_ptr Bdesc;
+  // CHECK-NEXT: void *C;
+  // CHECK-NEXT: dpct::blas_gemm::experimental::matrix_layout_ptr Cdesc;
+  // CHECK-NEXT: dpct::queue_ptr stream;
+  // CHECK-NEXT: dpct::blas_gemm::experimental::matrix_transform(transformDesc, alpha, A, Adesc, beta, B, Bdesc, C, Cdesc, stream);
+  cublasLtHandle_t lightHandle;
+  const void *alpha;
+  const void *A;
+  cublasLtMatrixLayout_t Adesc;
+  const void *beta;
+  const void *B;
+  cublasLtMatrixLayout_t Bdesc;
+  void *C;
+  cublasLtMatrixLayout_t Cdesc;
+  cudaStream_t stream;
+  cublasLtMatrixTransform(lightHandle, transformDesc, alpha, A, Adesc, beta, B, Bdesc, C, Cdesc, stream);
+}
diff --git a/clang/test/dpct/types007.cu b/clang/test/dpct/types007.cu
index 09ab927fbbff..d84954ee7140 100644
--- a/clang/test/dpct/types007.cu
+++ b/clang/test/dpct/types007.cu
@@ -19,8 +19,8 @@ int main(int argc, char **argv) {
   thrust::optional c = 1;
 }
 
-// CHECK: void foo_1(dpct::blas::compute_type a1) {
-// CHECK-NEXT: dpct::blas::compute_type b1 = a1;
+// CHECK: void foo_1(dpct::compute_type a1) {
+// CHECK-NEXT: dpct::compute_type b1 = a1;
 // CHECK-NEXT: }
 void foo_1(cublasComputeType_t a1) {
   cublasComputeType_t b1 = a1;
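
For reference, a minimal hand-written sketch of the migrated API that the CHECK lines above expect: handles, matmul descriptors, and matrix layouts become heap-allocated dpct::blas_gemm::experimental objects, the preference/algo/workspace arguments of cublasLtMatmul are dropped, and the stream maps to a dpct::queue_ptr. This is not part of the patch; the gemm_example name, the header include, the sizes, and the float/f32 type combination are illustrative assumptions (the set of type combinations actually accepted by matmul is defined in blas_gemm_utils.hpp, not shown here).

// Assumed include path; the patch installs the header as dpct/blas_gemm_utils.hpp.
#include <dpct/dpct.hpp>
#include <dpct/blas_gemm_utils.hpp>

// Column-major C = alpha*A*B + beta*C with illustrative sizes m=n=k=4.
void gemm_example(const float *A, const float *B, float *C,
                  dpct::queue_ptr stream) {
  using namespace dpct::blas_gemm::experimental;
  const uint64_t m = 4, n = 4, k = 4;
  const float alpha = 1.0f, beta = 0.0f;

  // cublasLtCreate/cublasLtMatmulDescCreate/cublasLtMatrixLayoutCreate
  // map to these constructors in the migrated code.
  descriptor_ptr handle = new descriptor();
  matmul_desc_ptr op_desc = new matmul_desc_t(
      dpct::compute_type::f32, dpct::library_data_t::real_float);
  matrix_layout_ptr a_desc =
      new matrix_layout_t(dpct::library_data_t::real_float, m, k, m);
  matrix_layout_ptr b_desc =
      new matrix_layout_t(dpct::library_data_t::real_float, k, n, k);
  matrix_layout_ptr c_desc =
      new matrix_layout_t(dpct::library_data_t::real_float, m, n, m);

  // cublasLtMatmul drops its algo/workspace parameters; C serves as both
  // the C and D operands here (in-place update).
  matmul(handle, op_desc, &alpha, A, a_desc, B, b_desc, &beta, C, c_desc, C,
         c_desc, stream);

  // cublasLt*Destroy calls become plain deletes.
  delete a_desc;
  delete b_desc;
  delete c_desc;
  delete op_desc;
  delete handle;
}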