[AMD] Add instruction schedule loop boundary guard hints #5163
@sjw, @antiagainst, @zhanglx13, could you please review the code?
third_party/amd/backend/compiler.py
Outdated
@@ -64,7 +64,7 @@ class HIPOptions:
     # Kernel library. Note, this variant requires the use of buffer load/store ops
     # and a special software pipelining style - i.e., 1x LDS and 1x register
     # prefetch buffers for each GEMM tile.
-    instruction_sched_variant: str = 'none'
+    instruction_sched_variant: str = 'guard'
Can you update the comment above to:
- Move the "none" variant from L51 to L57
- Add a new bullet for "guard". Also "guard" is too generic; what about naming it as "loop_boundary_guard" or something.
> Move the "none" variant from L51 to L57
Done
> Add a new bullet for "guard". Also "guard" is too generic; what about naming it as "loop_boundary_guard" or something.
Moved to a dedicated op and pass
case SchedulingType::LLVM_IGLP_0:
case SchedulingType::LLVM_IGLP_1:
case triton::amdgpu::SchedHint::llvm_iglp_0:
  [[fallthrough]];
I'm not sure what value these [[fallthrough]] markers provide. It's pretty clear to C/C++ folks that it should fall through. Having them just takes up multiple lines of code as a maintenance burden. Can you drop them, please?
The [[fallthrough]] attribute explicitly indicates intended fallthrough, for better code clarity and to suppress warnings from compilers. Are you sure you want to drop it?
Done
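For readers outside the thread, a minimal self-contained sketch of what the marker buys (the enum and function here are hypothetical illustrations, not Triton code): [[fallthrough]] documents that the omitted break is intentional and suppresses compiler warnings such as -Wimplicit-fallthrough.

```cpp
#include <string>

// Hypothetical severity levels, for illustration only (not from this PR).
enum class Sev { debug, info, warning, error };

std::string classify(Sev s) {
  switch (s) {
  case Sev::debug:
    [[fallthrough]]; // intentional: debug shares info's handling
  case Sev::info:
    return "routine";
  case Sev::warning:
    [[fallthrough]]; // intentional: warning shares error's handling
  case Sev::error:
    return "serious";
  }
  return "unknown"; // unreachable for valid enum values
}
```

Without the attribute, -Wimplicit-fallthrough would flag the two case labels that lack a break; with it, the compiler treats the fallthrough as documented intent.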
triton::amdgpu::SchedHint schedulingType =
    instructionSchedHint.getSchedVariant();
if ((this->numStages < 2) &&
    (schedulingType != triton::amdgpu::SchedHint::guard)) {
Is this the case for all other variants? (At least I see we don't need to emit the debug message below for the none case.)
SchedHint::guard has been removed from this pass.
mod.walk([this, ctx](scf::ForOp forOp) {
  auto maybeSchedHint = triton::amdgpu::symbolizeSchedHint(this->variant);
  if (!maybeSchedHint) {
    LDBG("Skipping instruction scheduling: unknown scheduling hint.");
Can you also print out the provided hint variant? For error messages, be explicit and verbose--it could save you some debugging time down the road. :)
Done
namespace {
void getAllInnerForOps(scf::ForOp forOp,
                       llvm::SetVector<scf::ForOp> &innermostForOps) {
  bool found = false;
  forOp.getBody()->walk([&found, &innermostForOps](scf::ForOp innerForOp) {
    getAllInnerForOps(innerForOp, innermostForOps);
    found = true;
  });
  if (!found)
    innermostForOps.insert(forOp);
}
} // namespace

namespace mlir::triton::AMD {
llvm::SetVector<scf::ForOp> getAllInnerForOps(mlir::triton::FuncOp funcOp) {
  llvm::SetVector<scf::ForOp> innermostForOps{};
  funcOp->walk(
      [&](scf::ForOp forOp) { ::getAllInnerForOps(forOp, innermostForOps); });
  return innermostForOps;
}
} // namespace mlir::triton::AMD
This implementation is more complicated than I'd expect. I think something like the following would do?
SmallVector<scf::ForOp> getLeafForOps(triton::FuncOp funcOp) {
SmallVector<scf::ForOp> allOps;
funcOp->walk([&](scf::ForOp forOp) { allOps.push_back(forOp); });
SmallVector<scf::ForOp> leafOps;
for (scf::ForOp forOp : allOps) {
  auto r = forOp.getBody()->walk([](scf::ForOp) { return WalkResult::interrupt(); });
if (!r.wasInterrupted()) leafOps.push_back(forOp);
}
return leafOps;
}
@antiagainst, totally agree with your implementation. It is more concise and does the same thing. Thanks!
Done
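As an aside, the "collect all loops, keep the childless ones" idea is easy to check outside MLIR. The following self-contained sketch mirrors the suggested getLeafForOps with a toy Loop struct standing in for scf::ForOp (illustration only, not Triton code):

```cpp
#include <vector>

// Toy stand-in for scf::ForOp nesting: each Loop owns its nested loops.
struct Loop {
  int id;
  std::vector<Loop> inner;
};

// Phase 1: collect every loop in pre-order.
void collectAll(const Loop &l, std::vector<const Loop *> &all) {
  all.push_back(&l);
  for (const Loop &child : l.inner)
    collectAll(child, all);
}

// Phase 2: keep only the loops with no nested loop (the "leaf" loops).
std::vector<int> leafLoopIds(const Loop &root) {
  std::vector<const Loop *> all;
  collectAll(root, all);
  std::vector<int> leaves;
  for (const Loop *l : all)
    if (l->inner.empty())
      leaves.push_back(l->id);
  return leaves;
}
```

For a nest like loop 0 containing loops 1 (which contains 2) and 3, only 2 and 3 are leaves.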
  return;
}

mod.walk([this](triton::FuncOp funcOp) {
I'm confused why we run this together with the logic at L537. They are effectively disjoint. Can you organize this together with L537 better?
The instruction scheduling hint was designed for a single tt.DotOp inside an scf.ForOp. It is complicated right now to extend it to support multiple tt.DotOps in a single region. There may be cases where a single tt.Load feeds multiple tt.DotOps (the same may apply to ttg.LocalLoad and ttg.LocalStore). In these cases, it is not 100% clear to which hint the corresponding ds_reads/ds_writes/global_loads/buffer_loads need to be attached for interleaving. Probably, ds_reads/global_loads/buffer_loads need to be attached to the first hint (which refers to the first tt.DotOp in the region) and ds_writes to the last one. However, there may be other computations in between any two tt.DotOps which may also involve load/store operations - e.g., tt.Load. As @sjw36 mentioned in our chat, we probably need to move to a proper DAG approach.
I'd like to point out that I thought we urgently needed the guard option for FA-like kernels; this is the main goal of this PR.
  this->variant = std::move(variant.str());
}

void guardFlashAttentionLikeProblems(triton::FuncOp funcOp) {
This disregards the developer's choice and forcefully sets it for attention. Why is that?
guardFlashAttentionLikeProblems has been removed in the new implementation.
-    LDBG("skipping `local-prefetch` scheduling given it needs `buffer_load` "
-         "instructions");
+    LDBG("skipping `local_prefetch` scheduling given it needs `buffer_load` "
+         "instructions.");
I've always wanted to ask: why does the local_prefetch scheduling need a specific type of instruction? In this particular case, why does global_load not work?
global_loads involve address calculations (VALU instructions) and have loop-carried dependencies. It turns out this messes up the scheduling imposed by the LLVM intrinsic calls - i.e., the compiler backend gets confused. In short, the CK-like scheduling works only with buffer_load instructions. Based on my experience, it is not worth applying local_prefetch to a GEMM variant which involves global_loads.
So buffer_load does the address update with SALU instructions?

> the compiler backend gets confused

How confused? Does it place those VALUs in the wrong place so that their latencies are exposed?
The PR title is misleading. We don't need anything special for flash-attention-like kernels.
I'd like to propose something different - i.e., a new instruction in our dialect which is dedicated to instruction scheduling guards.
  # Kernel library. Note, this variant requires the use of buffer load/store ops
  # and a special software pipelining style - i.e., 1x LDS and 1x register
  # prefetch buffers for each GEMM tile.
  instruction_sched_variant: str = 'none'

  # The following option prevents moves of instructions from the regions where they are defined.
Hmm, I'm not sure we want to introduce yet another separate knob here. The combinational effect is not something I'm convinced we need right now. Why can't this just be another option like the above? Even though this is experimental right now, I'd prefer to trim down the number of different variants to only the proven useful ones, not a combinatorial explosion of different configurations going unbounded.
If you need to turn on the boundary guard for some existing variants, it can be achieved with patterns. For example, you can have two patterns, one LowerBoundaryGuardToSchedBarrier and the other RemoveBoundaryGuard. If the variant is local_prefetch/etc., then you pull in the first pattern; otherwise pull in the second. This keeps the lowering patterns separate concerns while making them composable.
Hi @antiagainst, I disagree a bit. The point is that InstructionSchedHint and InstructionSchedGuard have different application scopes. The former can be applied only when there is a single tt.DotOp in a region and, thus, it is very specific (see the definition). The latter can be applied to any region regardless of what operations are inside (e.g., the FA kernel). We are going to get an ambiguous solution if we combine these two ops together. Additionally, the logic is going to be complicated and difficult to read.
The current solution provides some separation of concerns which, in my opinion, is good from the SW point of view. Did I manage to change your opinion?
-def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> {
-  let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
-  let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()";
+def TritonAMDGPUInsertInstructionControlLogic : Pass<"triton-amdgpu-insert-instruction-control-logic", "mlir::ModuleOp"> {
Following the previous comments about avoiding combinational knobs, here I'm not sure we want a proliferation of passes that each handle just one op. They should be folded into existing passes given that they are closely related in functionality. You can use the greedy pattern rewriter for simple 1:1 op rewrites.
@@ -491,14 +469,15 @@ struct TritonAMDGPULowerInstructionSchedHints
     ConversionTarget target(*ctx);
     target.addLegalDialect<LLVM::LLVMDialect>();
     target.addIllegalOp<triton::amdgpu::InstructionSchedHint>();
+    target.addLegalOp<triton::amdgpu::InstructionSchedGuard>();
I think this whole pass should be restructured. What about something like this: define a few patterns for lowering the sched hint op - LowerSchedHintIntoIGLP, LowerSchedHintIntoSchedGroupBarrier, RemoveSchedHint, etc. - and also, as I commented before, LowerBoundaryGuardIntoSchedBarrier and RemoveBoundaryGuard. Each pattern does no excessive checks; they just do the mechanical op conversion the name indicates. In runOnOperation, check the chosen sched variant and pull in patterns accordingly to run the greedy pattern rewriter. This gives us a clear structure: the switch/conditions are in the pass, with each pattern dedicated to its own task. Easy to grow/shrink to avoid massive/branchy patterns.
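A rough sketch of the structure being proposed here (the pattern names are the reviewer's suggestions, the surrounding MLIR boilerplate is assumed, and this is pseudocode rather than compilable Triton source):

```cpp
// Pseudocode sketch: pull in only the patterns relevant to the chosen
// variant; each pattern is a mechanical 1:1 rewrite with no extra checks.
void runOnOperation() override {
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  switch (schedVariant) {
  case SchedVariant::llvm_iglp_0:
  case SchedVariant::llvm_iglp_1:
    patterns.add<LowerSchedHintIntoIGLP, RemoveBoundaryGuard>(ctx);
    break;
  case SchedVariant::local_prefetch:
    patterns.add<LowerSchedHintIntoSchedGroupBarrier,
                 LowerBoundaryGuardIntoSchedBarrier>(ctx);
    break;
  default: // "none" and unknown variants: strip the hint/guard ops
    patterns.add<RemoveSchedHint, RemoveBoundaryGuard>(ctx);
    break;
  }
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
```

The switch lives in the pass; each pattern stays a single-purpose rewrite, which keeps the lowering easy to grow or shrink.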
Extended AMDGPU instruction scheduling.

The introduced source code changes add sched.barriers at the beginning and at the end of each scf.For op (called guards). The guards prevent moves of instructions from the basic blocks adjacent to the bodies of for-loops. According to test results, this increases performance for the FA-like kernels due to a reduction in VGPR spilling.

I am not making a trivial change, such as fixing a typo in a comment.
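Conceptually, the guard insertion amounts to something like the following MLIR-flavored pseudocode (the createSchedBarrier helper and the builder plumbing are assumptions for illustration; an llvm.amdgcn.sched.barrier call with mask 0 blocks all instruction movement across it):

```cpp
// Pseudocode sketch: place a full scheduling barrier (mask = 0) at the
// start and the end of the loop body, so the backend cannot hoist or sink
// instructions across the loop boundary.
void guardLoop(OpBuilder &builder, scf::ForOp forOp) {
  Block *body = forOp.getBody();
  builder.setInsertionPointToStart(body);
  createSchedBarrier(builder, forOp.getLoc(), /*mask=*/0);
  builder.setInsertionPoint(body->getTerminator());
  createSchedBarrier(builder, forOp.getLoc(), /*mask=*/0);
}
```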
I have written a PR description following these rules.
I have run pre-commit run --from-ref origin/main --to-ref HEAD.

Select one of the following:
- /test for lit tests
- /unittest for C++ tests
- /python/test for end-to-end tests

Select one of the following:
- lit tests
- The lit tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)