From 681a1f73408fd45e33f2f932a608bee668bff905 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Thu, 14 Nov 2024 18:35:33 +0000 Subject: [PATCH] [AMD] Added instr.sched guards for the FA-like kernels --- test/TritonGPU/amd/amd-instruction-sched.mlir | 12 +- third_party/amd/backend/compiler.py | 5 +- .../Dialect/TritonAMDGPU/IR/CMakeLists.txt | 2 + .../include/Dialect/TritonAMDGPU/IR/Dialect.h | 1 + .../TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td | 32 +++++ .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 7 +- .../amd/include/TritonAMDGPUToLLVM/Passes.h | 5 +- .../amd/include/TritonAMDGPUToLLVM/Passes.td | 11 +- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 2 + .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 132 ++++++++++-------- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 23 ++- .../amd/lib/TritonAMDGPUToLLVM/Utility.h | 4 + third_party/amd/python/triton_amd.cc | 10 +- 13 files changed, 164 insertions(+), 82 deletions(-) diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir index 011116e1b201..f1a3004c5246 100644 --- a/test/TritonGPU/amd/amd-instruction-sched.mlir +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -1,8 +1,8 @@ -// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 -// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' 
-triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints='variant=llvm_iglp_0' -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints='variant=llvm_iglp_1' -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints='variant=guard' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints='variant=guard' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints='variant=local_prefetch' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2 @@ -69,7 +69,7 @@ module { // INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> // USE_LOCAL_PREFETCH_GLOBAL_LOAD: [lower-insert-instruction-sched-hints] - // USE_LOCAL_PREFETCH_GLOBAL_LOAD-SAME: skipping 
`local-prefetch` scheduling given it needs `buffer_load` instructions
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD-SAME: skipping `local_prefetch` scheduling given it needs `buffer_load` instructions.
 
   // LABELING_PS_1: scf.for
   // LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index 0433f2458f32..e10dec8ac4f6 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -236,7 +236,7 @@ def make_ttgir(mod, metadata, options):
                           "equivalent behavior in the past.")
         amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages)
         passes.common.add_canonicalizer(pm)
-        amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
+        amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_reduce_data_duplication(pm)
@@ -294,8 +294,7 @@ def make_llir(src, metadata, options):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
-        amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages,
-                                                         options.instruction_sched_variant)
+        amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
         if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
             passes.llvmir.add_di_scope(pm)
         amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt
index 25a57075be01..094ecfc7d4e9 100644
--- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt
+++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt
@@ -11,6 +11,8 @@ add_mlir_doc(TritonAMDGPUOps TritonAMDGPUOps dialects/ -gen-op-doc)
 add_public_tablegen_target(TritonAMDGPUTableGen)
 
 set(LLVM_TARGET_DEFINITIONS TritonAMDGPUAttrDefs.td)
+mlir_tablegen(TritonAMDGPUEnums.h.inc -gen-enum-decls)
+mlir_tablegen(TritonAMDGPUEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(TritonAMDGPUAttrDefs.h.inc -gen-attrdef-decls)
 mlir_tablegen(TritonAMDGPUAttrDefs.cpp.inc -gen-attrdef-defs)
 add_public_tablegen_target(TritonAMDGPUAttrDefsIncGen)
diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h
index 6dbb0435e20c..c0d9e97556f6 100644
--- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h
+++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h
@@ -33,6 +33,7 @@
 
 // clang-format off
 #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc"
+#include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.h.inc"
 // clang-format on
 
 #define GET_ATTRDEF_CLASSES
diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td
index c0aa08421bdd..5936ccc78241 100644
--- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td
+++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td
@@ -26,6 +26,7 @@
 
 include "mlir/IR/AttrTypeBase.td"
 include "TritonAMDGPUDialect.td"
+include "mlir/IR/EnumAttr.td"
 
 class TritonAMDGPU_Attr<string name, list<Trait> traits = [],
                         string baseCppClass = "::mlir::Attribute">
@@ -59,4 +60,35 @@ def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> {
 }
 
+class TritonAMDGPU_I32Enum<string name, string summary, list<I32EnumAttrCase> cases>
+    : I32EnumAttr<name, summary, cases> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::triton::amdgpu";
+}
+
+class TritonAMDGPU_I32EnumAttr<string mnemonic, TritonAMDGPU_I32Enum enumInfo> :
+    EnumAttr<TritonAMDGPU_Dialect, enumInfo, mnemonic> {
+  let assemblyFormat = "`<` $value `>`";
+  let cppNamespace = "::mlir::triton::amdgpu";
+}
+
+def SchedHintCaseNone : I32EnumAttrCase<"none", 0>;
+def SchedHintCaseLLVMIglp0 : I32EnumAttrCase<"llvm_iglp_0", 1>;
+def SchedHintCaseLLVMIglp1 : I32EnumAttrCase<"llvm_iglp_1", 2>;
+def SchedHintCaseLocalPrefetch : I32EnumAttrCase<"local_prefetch", 3>;
+def SchedHintCaseGuard : I32EnumAttrCase<"guard", 4>;
+
+def TritonAMDGPU_SchedHintsEnum : TritonAMDGPU_I32Enum<
+  "SchedHint", "Instruction Scheduling Hints for AMD GPUs", [
+    SchedHintCaseNone,
+    SchedHintCaseLLVMIglp0,
+    SchedHintCaseLLVMIglp1,
+    SchedHintCaseLocalPrefetch,
+    SchedHintCaseGuard
+  ]>;
+
+def TritonAMDGPU_SchedHintVariantAttr :
+  TritonAMDGPU_I32EnumAttr<"SchedHintVariant", TritonAMDGPU_SchedHintsEnum>;
+
+
 #endif
diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
index 68c50d48635b..3c4ccd8db340 100644
--- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
+++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
@@ -58,6 +58,7 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
   }];
 
   let arguments = (ins
+    TritonAMDGPU_SchedHintVariantAttr:$schedVariant,
     TritonAMDGPU_InstCounter:$numDsReadsA,
     TritonAMDGPU_InstCounter:$numDsReadsB,
     TritonAMDGPU_InstCounter:$numDsWritesA,
@@ -70,12 +71,12 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
   );
 
   let builders = [
-    OpBuilder<(ins), [{
+    OpBuilder<(ins "SchedHint":$variant), [{
       auto ctx = $_state.getContext();
       auto noneType = NoneType::get(ctx);
       auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType);
-      build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr,
-            emptyAttr, emptyAttr, false, false, emptyAttr);
+      build($_builder, $_state, variant, emptyAttr, emptyAttr, emptyAttr,
+            emptyAttr, emptyAttr, emptyAttr, false, false, emptyAttr);
     }]>
   ];
 
diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
index f592694295a8..0b1839a3d323 100644
--- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
+++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
@@ -36,11 +36,10 @@ createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
 std::unique_ptr<OperationPass<ModuleOp>>
 createConvertBuiltinFuncToLLVMPass(bool ftz);
 std::unique_ptr<OperationPass<ModuleOp>>
-createTritonAMDGPUInsertInstructionSchedHintsPass();
+createTritonAMDGPUInsertInstructionSchedHintsPass(StringRef variant);
 std::unique_ptr<OperationPass<ModuleOp>>
 createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch,
-                                                 int32_t numStages,
-                                                 StringRef variant);
+                                                 int32_t numStages);
 
 #define GEN_PASS_REGISTRATION
 #include "TritonAMDGPUToLLVM/Passes.h.inc"
diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
index 58815daec717..0572d12a928e 100644
--- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
+++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -61,15 +61,20 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::ModuleOp"> {
 
 def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> {
   let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
-  let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()";
+  let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass(/*variant=*/\"\")";
   let dependentDialects = ["mlir::LLVM::LLVMDialect",
                            "mlir::triton::amdgpu::TritonAMDGPUDialect"];
+
+  let options = [
+    Option<"variant", "variant", "std::string", /*default*/"\"\"",
+           "instruction scheduling variant">,
+  ];
 }
 
 def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
   let summary = "Lower instruction scheduling hints to LLVM intrinsics";
-  let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2, /*variant=*/\"\")";
+  let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2)";
 
   let dependentDialects = ["mlir::LLVM::LLVMDialect",
                            "mlir::ROCDL::ROCDLDialect",
@@ -80,8 +85,6 @@ def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-in
            "gfx target device architecture, e.g., gfx942">,
     Option<"numStages", "num_stages", "int32_t", /*default*/"2",
            "number of pipeline stages">,
-    Option<"variant", "variant", "std::string", /*default*/"\"none\"",
-           "instruction scheduling variant">,
   ];
 }
 
diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
index 1e429fdc39a9..73d5b27f42d5 100644
--- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
+++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
@@ -48,6 +48,8 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() {
       >();
 }
 
+#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.cpp.inc"
+
 #define GET_ATTRDEF_CLASSES
 #include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc"
 
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
index d93f2ca6c6ec..656b0efe2a17 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
@@ -1,9 +1,11 @@
 #include "SchedInstructions.h"
 #include "TritonAMDGPUToLLVM/Passes.h"
 #include "TritonAMDGPUToLLVM/TargetUtils.h"
+#include "Utility.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Pass/Pass.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 
@@ -213,36 +215,11 @@ struct InstructionSchedHintsRewriter
     : public OpRewritePattern<triton::amdgpu::InstructionSchedHint> {
 
   InstructionSchedHintsRewriter(MLIRContext *ctx, StringRef arch,
-                                int32_t numStages, std::string variant)
+                                int32_t numStages)
       : OpRewritePattern(ctx), numStages(numStages) {
-    this->machineDescr = MachineDescr::get(arch);
-    std::transform(variant.begin(), variant.end(), variant.begin(),
-                   [](unsigned char c) { return std::tolower(c); });
-
-    this->schedulingType =
-        llvm::StringSwitch<SchedulingType>(variant)
-            .Case("none", SchedulingType::NONE)
-            .Case("llvm-iglp-0", SchedulingType::LLVM_IGLP_0)
-            .Case("llvm-iglp-1", SchedulingType::LLVM_IGLP_1)
-            .Case("local-prefetch", SchedulingType::LOCAL_PREFETCH)
-            .Default(SchedulingType::UNKNOWN);
-
-    if (this->numStages < 2) {
-      this->schedulingType = SchedulingType::NONE;
-      LDBG("ignoring instruction scheduling due to a very low num. "
-           "stages value. Must be >= 2");
-    }
   }
 
-  enum class SchedulingType : uint32_t {
-    NONE = 0,
-    LLVM_IGLP_0,
-    LLVM_IGLP_1,
-    LOCAL_PREFETCH,
-    UNKNOWN
-  };
-
   // The following is inspired by ROCm Composable Kernel library's V3 pipelining
   // (see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp).
   // This scheduling requires 1x register and 1x LDS buffers combined with the
@@ -253,8 +230,8 @@ struct InstructionSchedHintsRewriter
 
     if (!(schedHint.getIsBufferLoadsAEnabled() &&
           schedHint.getIsBufferLoadsBEnabled())) {
-      LDBG("skipping `local-prefetch` scheduling given it needs `buffer_load` "
-           "instructions");
+      LDBG("skipping `local_prefetch` scheduling given it needs `buffer_load` "
+           "instructions.");
       return;
     }
 
@@ -416,15 +393,14 @@ struct InstructionSchedHintsRewriter
   LogicalResult
   matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint,
                   PatternRewriter &rewriter) const override {
-    if (this->schedulingType == SchedulingType::NONE) {
-      rewriter.eraseOp(instructionSchedHint);
-      return success();
-    }
 
-    if (this->schedulingType == SchedulingType::UNKNOWN) {
-      instructionSchedHint.emitError(
-          "unknown instruction scheduling variant has been provided");
-      return failure();
+    triton::amdgpu::SchedHint schedulingType =
+        instructionSchedHint.getSchedVariant();
+
+    if ((this->numStages < 2) &&
+        (schedulingType != triton::amdgpu::SchedHint::guard)) {
+      LDBG("ignoring instruction scheduling due to a very low num. "
+           "stages value. Must be >= 2");
+      schedulingType = triton::amdgpu::SchedHint::none;
     }
 
     // The switch controls whether instructions are allowed to cross the basic
@@ -432,9 +408,9 @@ struct InstructionSchedHintsRewriter
     // not supposed to be used together with IGLP OPT according to the AMDGPU
     // backend documentation.
     const bool limitSchedulingRange =
-        !(schedulingType == SchedulingType::NONE ||
-          schedulingType == SchedulingType::LLVM_IGLP_0 ||
-          schedulingType == SchedulingType::LLVM_IGLP_1);
+        !(schedulingType == triton::amdgpu::SchedHint::none ||
+          schedulingType == triton::amdgpu::SchedHint::llvm_iglp_0 ||
+          schedulingType == triton::amdgpu::SchedHint::llvm_iglp_1);
     Location loc = instructionSchedHint->getLoc();
     Block *block = instructionSchedHint->getBlock();
     if (limitSchedulingRange) {
@@ -446,17 +422,24 @@ struct InstructionSchedHintsRewriter
     rewriter.setInsertionPoint(block, std::prev(block->end()));
 
     switch (schedulingType) {
-    case SchedulingType::LLVM_IGLP_0:
-    case SchedulingType::LLVM_IGLP_1:
+    case triton::amdgpu::SchedHint::llvm_iglp_0:
+      [[fallthrough]];
+    case triton::amdgpu::SchedHint::llvm_iglp_1: {
       createIglpOpt(rewriter, loc, static_cast<int>(schedulingType) - 1);
       break;
-    case SchedulingType::LOCAL_PREFETCH:
+    }
+    case triton::amdgpu::SchedHint::local_prefetch: {
       createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint);
       break;
-    case SchedulingType::NONE:
-    default:
+    }
+    case triton::amdgpu::SchedHint::guard:
+      [[fallthrough]];
+    case triton::amdgpu::SchedHint::none:
+      [[fallthrough]];
+    default: {
       break;
     }
+    }
 
     if (limitSchedulingRange)
       createSchedBarrier(rewriter, loc,
@@ -468,7 +451,6 @@ struct InstructionSchedHintsRewriter
 
 private:
   int32_t numStages;
-  SchedulingType schedulingType;
   std::unique_ptr<MachineDescr> machineDescr;
 };
 
@@ -477,11 +459,9 @@ struct TritonAMDGPULowerInstructionSchedHints
           TritonAMDGPULowerInstructionSchedHints> {
 
   explicit TritonAMDGPULowerInstructionSchedHints(StringRef arch,
-                                                  int32_t numStages,
-                                                  StringRef variant) {
+                                                  int32_t numStages) {
     this->arch = std::move(arch.str());
     this->numStages = numStages;
-    this->variant = std::move(variant.str());
   }
 
   void runOnOperation() override {
@@ -498,7 +478,7 @@ struct TritonAMDGPULowerInstructionSchedHints
 
     RewritePatternSet patterns(ctx);
     patterns.add<InstructionSchedHintsRewriter>(ctx, this->arch,
-                                                this->numStages, this->variant);
+                                                this->numStages);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
@@ -512,18 +492,57 @@ struct TritonAMDGPUInsertInstructionSchedHints
     : public triton::impl::TritonAMDGPUInsertInstructionSchedHintsBase<
           TritonAMDGPUInsertInstructionSchedHints> {
 
+  explicit TritonAMDGPUInsertInstructionSchedHints(StringRef variant) {
+    this->variant = std::move(variant.str());
+  }
+
+  void guardFlashAttentionLikeProblems(triton::FuncOp funcOp) {
+    llvm::SetVector<scf::ForOp> innermostForOps =
+        triton::AMD::getAllInnerForOps(funcOp);
+    for (auto forOp : innermostForOps) {
+      size_t gemmCounter = 0;
+      size_t reduceCounter = 0;
+      bool hasExpOp = false;
+      forOp->walk([&](triton::DotOp) { ++gemmCounter; });
+      forOp->walk([&](triton::ReduceOp) { ++reduceCounter; });
+      forOp->walk([&](math::Exp2Op) { hasExpOp = true; });
+      if ((gemmCounter > 1) && (reduceCounter > 1) && hasExpOp) {
+        OpBuilder builder(forOp->getContext());
+        Block *block = forOp.getBody();
+        builder.setInsertionPoint(block, std::prev(block->end()));
+        builder.create<triton::amdgpu::InstructionSchedHint>(
+            forOp->getLoc(), triton::amdgpu::SchedHint::guard);
+      }
+    }
+  }
+
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     ModuleOp mod = getOperation();
-    mod.walk([this, ctx](scf::ForOp forOp) {
+
+    auto maybeSchedHint = triton::amdgpu::symbolizeSchedHint(this->variant);
+    if (!maybeSchedHint) {
+      LDBG("Skipping instruction scheduling: unknown scheduling hint.");
+      return;
+    }
+
+    mod.walk([this](triton::FuncOp funcOp) {
+      guardFlashAttentionLikeProblems(funcOp);
+    });
+
+    triton::amdgpu::SchedHint schedHint = maybeSchedHint.value();
+    if (schedHint == triton::amdgpu::SchedHint::none)
+      return;
+
+    mod.walk([this, ctx, schedHint](scf::ForOp forOp) {
       // Note, instruction schedule barriers are inserted only in the case of
       // a single `tt.dot` op in a `scf::ForOp` scope in the current
       // implementation.
       if (auto dotOp = getSingleDotOpIfExists(forOp)) {
         OpBuilder rewriter(ctx);
         rewriter.setInsertionPointAfter(dotOp);
-        rewriter.create<triton::amdgpu::InstructionSchedHint>(dotOp->getLoc());
+        rewriter.create<triton::amdgpu::InstructionSchedHint>(dotOp->getLoc(),
+                                                              schedHint);
       }
     });
   }
@@ -533,14 +552,13 @@
 
 namespace mlir::triton {
 std::unique_ptr<OperationPass<ModuleOp>>
 createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch,
-                                                 int32_t numStages,
-                                                 StringRef variant) {
-  return std::make_unique<TritonAMDGPULowerInstructionSchedHints>(
-      arch, numStages, variant);
+                                                 int32_t numStages) {
+  return std::make_unique<TritonAMDGPULowerInstructionSchedHints>(arch,
+                                                                  numStages);
 }
 
 std::unique_ptr<OperationPass<ModuleOp>>
-createTritonAMDGPUInsertInstructionSchedHintsPass() {
-  return std::make_unique<TritonAMDGPUInsertInstructionSchedHints>();
+createTritonAMDGPUInsertInstructionSchedHintsPass(StringRef variant) {
+  return std::make_unique<TritonAMDGPUInsertInstructionSchedHints>(variant);
 }
 } // namespace mlir::triton
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
index 0bd401f1993a..bc7c130aea59 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
@@ -357,5 +357,26 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
       appendOrGetExternFuncOp(rewriter, parent, funcName, funcType);
   LLVM::createLLVMCallOp(rewriter, loc, funcOp, ValueRange({ptr, val, pred}));
 }
-
 } // namespace mlir::LLVM::AMD
+
+namespace {
+void getAllInnerForOps(scf::ForOp forOp,
+                       llvm::SetVector<scf::ForOp> &innermostForOps) {
+  bool found = false;
+  forOp.getBody()->walk([&found, &innermostForOps](scf::ForOp innerForOp) {
+    getAllInnerForOps(innerForOp, innermostForOps);
+    found = true;
+  });
+  if (!found)
+    innermostForOps.insert(forOp);
+}
+} // namespace
+
+namespace mlir::triton::AMD {
+llvm::SetVector<scf::ForOp> getAllInnerForOps(mlir::triton::FuncOp funcOp) {
+  llvm::SetVector<scf::ForOp> innermostForOps{};
+  funcOp->walk(
+      [&](scf::ForOp forOp) { ::getAllInnerForOps(forOp, innermostForOps); });
+  return innermostForOps;
+}
+} // namespace mlir::triton::AMD
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
index d150531848e3..74ed4c56dd25 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
@@ -49,4 +49,8 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
              triton::CacheModifier cm = triton::CacheModifier::NONE);
 } // namespace mlir::LLVM::AMD
 
+namespace mlir::triton::AMD {
+llvm::SetVector<scf::ForOp> getAllInnerForOps(mlir::triton::FuncOp funcOp);
+} // namespace mlir::triton::AMD
+
 #endif
diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
index 9eab3771263e..5167540b71f5 100644
--- a/third_party/amd/python/triton_amd.cc
+++ b/third_party/amd/python/triton_amd.cc
@@ -44,14 +44,14 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
   m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm, bool ftz) {
     pm.addPass(createConvertBuiltinFuncToLLVMPass(ftz));
   });
-  m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm) {
-    pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass());
+  m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm,
+                                             const std::string &variant) {
+    pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass(variant));
   });
   m.def("lower_instruction_sched_hints",
-        [](mlir::PassManager &pm, const std::string &arch, int32_t numStages,
-           const std::string &variant) {
+        [](mlir::PassManager &pm, const std::string &arch, int32_t numStages) {
           pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(
-              arch, numStages, variant));
+              arch, numStages));
         });
   m.def("add_decompose_unsupported_conversions",
         [](mlir::PassManager &pm, const std::string &arch) {