diff --git a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h index 029e67277213..3adbc3eeef51 100644 --- a/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h +++ b/compiler/bindings/c/iree/compiler/dialects/iree_codegen.h @@ -66,6 +66,14 @@ MLIR_CAPI_EXPORTED MlirAttribute ireeCodegenCompilationInfoAttrGet( MLIR_CAPI_EXPORTED ireeCodegenCompilationInfoParameters ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr); +MLIR_CAPI_EXPORTED void +ireeCodegenGetExecutableVariantOps(MlirModule module, size_t *numOps, + MlirOperation *executableOps); + +MLIR_CAPI_EXPORTED void ireeCodegenQueryMMAIntrinsics(MlirOperation op, + size_t *numIntrinsics, + uint32_t *mmaIntrinsics); + #ifdef __cplusplus } #endif diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp index 7ece224519c2..dec10fb033df 100644 --- a/compiler/bindings/python/IREECompilerDialectsModule.cpp +++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp @@ -21,6 +21,33 @@ static const char *kGpuModuleImportPath = namespace py = pybind11; using namespace mlir::python::adaptors; +static std::vector +ireeCodegenGetExecutableVariantOpsBinding(MlirModule module) { + size_t numOps = 0; + ireeCodegenGetExecutableVariantOps(module, &numOps, nullptr); + std::vector ops(numOps); + ireeCodegenGetExecutableVariantOps(module, &numOps, ops.data()); + + return ops; +} + +static std::vector +ireeCodegenQueryMMAIntrinsicsBinding(MlirOperation op) { + size_t numMMAs = 0; + ireeCodegenQueryMMAIntrinsics(op, &numMMAs, nullptr); + std::vector mmaIntrinsics(numMMAs); + ireeCodegenQueryMMAIntrinsics(op, &numMMAs, mmaIntrinsics.data()); + + py::object mmaIntrinsicEnum = + py::module_::import(kGpuModuleImportPath).attr("MMAIntrinsic"); + std::vector mmaList(numMMAs); + for (size_t i = 0; i < numMMAs; ++i) { + mmaList[i] = mmaIntrinsicEnum(mmaIntrinsics[i]); + } + + return mmaList; +} + PYBIND11_MODULE(_ireeCompilerDialects, m) { m.doc() = "iree-compiler dialects python extension"; @@ -326,4 +353,22 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) { "Gets an #iree_gpu.lowering_config from parameters.") .def_property_readonly("attributes", ireeGPULoweringConfigAttrGetAttributes); + + //===-------------------------------------------------------------------===// + // Binding to utility function getExecutableVariantOps + //===-------------------------------------------------------------------===// + + iree_codegen_module.def( + "get_executable_variant_ops", &ireeCodegenGetExecutableVariantOpsBinding, + "Gets the executable variant operations from a module.", + py::arg("module")); + + //===-------------------------------------------------------------------===// + // Binding to utility function queryMMAIntrinsics + //===-------------------------------------------------------------------===// + + iree_codegen_module.def( + "query_mma_intrinsics", &ireeCodegenQueryMMAIntrinsicsBinding, + "Queries the MMA intrinsics from an executable variant op.", + py::arg("op")); } diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel index 9b6d2b8fcca9..2d67e440e014 100644 --- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel +++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel @@ -137,6 +137,7 @@ iree_compiler_cc_library( deps = [ "//compiler/bindings/c:headers", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", + "//compiler/src/iree/compiler/Codegen/Utils", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:CAPIIRHeaders", "@llvm-project//mlir:IR", diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt index e0ec31da4ac3..871fbf35afa9 100644 --- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt +++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt @@ -116,6 +116,7 @@ iree_cc_library( MLIRCAPIIR MLIRIR iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect + iree::compiler::Codegen::Utils iree::compiler::bindings::c::headers PUBLIC ) diff --git a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp index c295d48b01e3..82e24960bcf7 100644 --- a/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREECodegenDialectCAPI.cpp @@ -10,6 +10,7 @@ #include #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/dialects/iree_codegen.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/IR.h" @@ -24,6 +25,8 @@ using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline; using mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipelineAttr; using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttrInterface; using mlir::iree_compiler::IREE::Codegen::TranslationInfoAttr; +using mlir::iree_compiler::IREE::GPU::MMAIntrinsic; +using mlir::iree_compiler::IREE::HAL::ExecutableVariantOp; bool ireeAttributeIsACodegenDispatchLoweringPassPipelineAttr( MlirAttribute attr) { @@ -149,3 +152,49 @@ ireeCodegenCompilationInfoAttrGetParameters(MlirAttribute attr) { parameters.translationInfo = wrap(compilationInfo.getTranslationInfo()); return parameters; } + +void ireeCodegenGetExecutableVariantOps(MlirModule module, size_t *numOps, + MlirOperation *executableOps) { + assert(!mlirModuleIsNull(module) && "module cannot be nullptr"); + assert(numOps && "numOps cannot be nullptr"); + + mlir::ModuleOp moduleOp = unwrap(module); + llvm::SmallVector executableVariantOps = + mlir::iree_compiler::getExecutableVariantOps(moduleOp); + + if (!executableOps) { + *numOps = executableVariantOps.size(); + return; + } + + assert( + *numOps == executableVariantOps.size() && + "*numOps must match the number of elements in the executableVariantOps"); + + for (size_t i = 0, e = executableVariantOps.size(); i < e; ++i) { + executableOps[i] = wrap(executableVariantOps[i]); + } +} + +void ireeCodegenQueryMMAIntrinsics(MlirOperation op, size_t *numIntrinsics, + uint32_t *mmaIntrinsics) { + assert(numIntrinsics && "numIntrinsics cannot be nullptr"); + + mlir::Operation *mlirOp = unwrap(op); + auto variantOp = llvm::dyn_cast_if_present(mlirOp); + assert(variantOp && "operation is not a ExecutableVariantOp"); + + llvm::SmallVector intrinsics = + mlir::iree_compiler::queryMMAIntrinsics(variantOp); + if (!mmaIntrinsics) { + *numIntrinsics = intrinsics.size(); + return; + } + + assert(*numIntrinsics == intrinsics.size() && + "*numIntrinsics must match the number of elements in the intrinsics"); + + for (size_t i = 0, e = intrinsics.size(); i < e; ++i) { + mmaIntrinsics[i] = static_cast(intrinsics[i]); + } +} diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c index ffb8f086c678..7f1b55044a8d 100644 --- a/compiler/src/iree/compiler/API/api_exports.c +++ b/compiler/src/iree/compiler/API/api_exports.c @@ -24,6 +24,8 @@ extern void ireeCodegenCompilationInfoAttrGetTypeID(); extern void ireeCodegenDispatchLoweringPassPipelineAttrGet(); extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID(); extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue(); +extern void ireeCodegenGetExecutableVariantOps(); +extern void ireeCodegenQueryMMAIntrinsics(); extern void ireeCodegenTranslationInfoAttrGet(); extern void ireeCodegenTranslationInfoAttrGetParameters(); extern void ireeCodegenTranslationInfoAttrGetTypeID(); diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def index ed5e12cceb48..2280a69d2912 100644 --- a/compiler/src/iree/compiler/API/api_exports.def +++ b/compiler/src/iree/compiler/API/api_exports.def @@ -14,6 +14,8 @@ EXPORTS ireeCodegenDispatchLoweringPassPipelineAttrGet ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID ireeCodegenDispatchLoweringPassPipelineAttrGetValue + ireeCodegenGetExecutableVariantOps + ireeCodegenQueryMMAIntrinsics ireeCodegenTranslationInfoAttrGet ireeCodegenTranslationInfoAttrGetParameters ireeCodegenTranslationInfoAttrGetTypeID diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld index 0808927de527..5bd3b256d83f 100644 --- a/compiler/src/iree/compiler/API/api_exports.ld +++ b/compiler/src/iree/compiler/API/api_exports.ld @@ -15,6 +15,8 @@ VER_0 { ireeCodegenDispatchLoweringPassPipelineAttrGet; ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID; ireeCodegenDispatchLoweringPassPipelineAttrGetValue; + ireeCodegenGetExecutableVariantOps; + ireeCodegenQueryMMAIntrinsics; ireeCodegenTranslationInfoAttrGet; ireeCodegenTranslationInfoAttrGetParameters; ireeCodegenTranslationInfoAttrGetTypeID; diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst index 11169bf3f13d..f92e98f299d2 100644 --- a/compiler/src/iree/compiler/API/api_exports.macos.lst +++ b/compiler/src/iree/compiler/API/api_exports.macos.lst @@ -13,6 +13,8 @@ _ireeCodegenCompilationInfoAttrGetTypeID _ireeCodegenDispatchLoweringPassPipelineAttrGet _ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID _ireeCodegenDispatchLoweringPassPipelineAttrGetValue +_ireeCodegenGetExecutableVariantOps +_ireeCodegenQueryMMAIntrinsics _ireeCodegenTranslationInfoAttrGet _ireeCodegenTranslationInfoAttrGetParameters _ireeCodegenTranslationInfoAttrGetTypeID diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp index f1eb77677205..612183d94eda 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp @@ -1030,7 +1030,7 @@ std::optional getGPUSubgroupSize(mlir::FunctionOpInterface func) { SmallVector getExecutableVariantOps(mlir::ModuleOp moduleOp) { - llvm::SmallVector executableVariantOps; + SmallVector executableVariantOps; moduleOp.walk([&](IREE::HAL::ExecutableVariantOp executableOp) { executableVariantOps.push_back(executableOp); }); @@ -1039,7 +1039,7 @@ getExecutableVariantOps(mlir::ModuleOp moduleOp) { SmallVector queryMMAIntrinsics(IREE::HAL::ExecutableVariantOp executableOp) { - llvm::SmallVector mmaIntrinsics; + SmallVector mmaIntrinsics; if (IREE::GPU::TargetAttr target = getGPUTargetAttr(executableOp)) { mmaIntrinsics = llvm::map_to_vector( target.getWgp().getMma(),