Skip to content

Commit

Permalink
[AMD] reverted changes in the language front-end
Browse files · Browse the repository at this point in the history
  • Loading branch information
ravil-mobile committed Nov 21, 2024
1 parent 61fba4f commit 661a12b
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 45 deletions.
3 changes: 0 additions & 3 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,6 @@ def visit_For(self, node):
return
num_stages = None
loop_unroll_factor = None
use_instructions_sched_guard = False
if IteratorClass is language.range:
iterator = IteratorClass(*iter_args, **iter_kwargs)
# visit iterator arguments
Expand All @@ -941,7 +940,6 @@ def visit_For(self, node):
step = iterator.step
num_stages = iterator.num_stages
loop_unroll_factor = iterator.loop_unroll_factor
use_instructions_sched_guard = iterator.use_instructions_guard
elif IteratorClass is range:
# visit iterator arguments
# note: only `range` iterator is supported now
Expand Down Expand Up @@ -1016,7 +1014,6 @@ def visit_For(self, node):
for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
if loop_unroll_factor is not None:
for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))
for_op.set_attr("tt.use_instructions_sched_guard", self.builder.get_bool_attr(use_instructions_sched_guard))

self.scf_stack.append(node)
for_op_body = for_op.get_body(0)
Expand Down
4 changes: 1 addition & 3 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2730,8 +2730,7 @@ def kernel(...):
this value implies no unrolling.
"""

def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
use_instructions_guard=False):
def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None):
if step is None:
self.step = constexpr(1)
else:
Expand All @@ -2744,7 +2743,6 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_fact
self.end = arg2
self.num_stages = num_stages
self.loop_unroll_factor = loop_unroll_factor
self.use_instructions_guard = use_instructions_guard

def __iter__(self):
raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
Expand Down
4 changes: 2 additions & 2 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ def make_ttgir(mod, metadata, options):
stream_prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1"
use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"

# The `local-prefetch` scheduling variant requires turning on buffer ops.
if options.instruction_sched_variant == "local-prefetch":
# The `local_prefetch` scheduling variant requires turning on buffer ops.
if options.instruction_sched_variant == "local_prefetch":
stream_prefetch = use_buffer_ops = True

if amd.has_matrix_core_feature(options.arch):
Expand Down
66 changes: 31 additions & 35 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,20 +225,20 @@ struct InstructionSchedHintsRewriter
// (see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp).
// This scheduling requires 1x register and 1x LDS buffers combined with the
// local (LDS to registers) and global (HBM to registers) data prefetching.
void createLocalPrefetchSchedule(
LogicalResult createLocalPrefetchSchedule(
PatternRewriter &rewriter, Location loc,
triton::amdgpu::InstructionSchedHint schedHint) const {

if (!(schedHint.getIsBufferLoadsAEnabled() &&
schedHint.getIsBufferLoadsBEnabled())) {
LDBG("skipping `local_prefetch` scheduling given it needs `buffer_load` "
"instructions.");
return;
return failure();
}

if (!machineDescr) {
schedHint.emitError("unknown target architecture detected");
return;
return failure();
}

const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue();
Expand All @@ -254,12 +254,12 @@ struct InstructionSchedHintsRewriter

if (numBufferLoadInstA == 0) {
schedHint.emitError("buffer load count for tile A must be initialized");
return;
return failure();
}

if (numBufferLoadInstB == 0) {
schedHint.emitError("buffer load count for tile B must be initialized");
return;
return failure();
}

const uint32_t numMmaInst = schedHint.getNumMMAs().getValue();
Expand All @@ -268,7 +268,7 @@ struct InstructionSchedHintsRewriter
auto maybeMmaExecCycle = machineDescr->getMmaExecCycle(mmaType.getShape());
if (llvm::failed(maybeMmaExecCycle)) {
schedHint.emitError("unknown mma instruction type");
return;
return failure();
}
const uint32_t mmaExecCycle = maybeMmaExecCycle.value();

Expand Down Expand Up @@ -389,6 +389,7 @@ struct InstructionSchedHintsRewriter
targetFeatures.push_back(str_attr("-load-store-opt"));
funcOp.setTargetFeaturesAttr(
::mlir::LLVM::TargetFeaturesAttr::get(ctx, targetFeatures));
return success();
}

LogicalResult
Expand All @@ -404,6 +405,13 @@ struct InstructionSchedHintsRewriter
schedulingType = triton::amdgpu::SchedHint::none;
}

// The switch controls whether instructions are allowed to cross the basic
// block boundaries at the very top and at the very bottom. Note, this is
// not supposed to be used together with IGLP OPT according to the AMDGPU
// backend documentation.
bool limitSchedulingRange =
schedulingType == triton::amdgpu::SchedHint::local_prefetch;

Location loc = instructionSchedHint->getLoc();
Block *block = instructionSchedHint->getBlock();
rewriter.setInsertionPoint(block, std::prev(block->end()));
Expand All @@ -416,21 +424,17 @@ struct InstructionSchedHintsRewriter
break;
}
case triton::amdgpu::SchedHint::local_prefetch: {
createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint);
LogicalResult result =
createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint);
if (failed(result))
limitSchedulingRange = false;
break;
}
default: {
break;
}
}

// The switch controls whether instructions are allowed to cross the basic
// block boundaries at the very top and at the very bottom. Note, this is
// not supposed to be used together with IGLP OPT according to the AMDGPU
// backend documentation.
const bool limitSchedulingRange =
schedulingType == triton::amdgpu::SchedHint::local_prefetch;

auto scanResult = block->walk([](triton::amdgpu::InstructionSchedGuard) {
return WalkResult::interrupt();
});
Expand Down Expand Up @@ -497,27 +501,19 @@ struct TritonAMDGPUInsertInstructionControlLogic
MLIRContext *ctx = &getContext();
ModuleOp mod = getOperation();

mod->walk([&](scf::ForOp forOp) {
// use the global kernel parameter to decide whether instruction guards
// need to be inserted
bool insertInstructionGuards = this->useInstructionSchedGuards;

// use a region local parameter to decide whether instruction guards need
// to be inserted Note, the local parameter has higher precedency of the
// global one
static const std::string localInstructionSchedGuardAttrName =
"tt.use_instructions_sched_guard";
if (auto localInstructionGuardingInfo = forOp->getAttrOfType<BoolAttr>(
localInstructionSchedGuardAttrName)) {
insertInstructionGuards = localInstructionGuardingInfo.getValue();
}
if (insertInstructionGuards) {
OpBuilder builder(forOp->getContext());
Block *block = forOp.getBody();
builder.setInsertionPoint(block, std::prev(block->end()));
builder.create<triton::amdgpu::InstructionSchedGuard>(forOp.getLoc());
}
});
// use the global kernel parameter to decide whether to insert instruction
// scheduling guards
if (this->useInstructionSchedGuards) {
mod.walk([&](triton::FuncOp funcOp) {
SmallVector<scf::ForOp> leafForOps = AMD::getLeafForOps(funcOp);
for (auto forOp : leafForOps) {
OpBuilder builder(forOp->getContext());
Block *block = forOp.getBody();
builder.setInsertionPoint(block, std::prev(block->end()));
builder.create<triton::amdgpu::InstructionSchedGuard>(forOp.getLoc());
}
});
}

std::string allSchedVariants;
llvm::raw_string_ostream os(allSchedVariants);
Expand Down
5 changes: 3 additions & 2 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,9 @@ SmallVector<scf::ForOp> getLeafForOps(triton::FuncOp funcOp) {

SmallVector<scf::ForOp> leafOps;
for (scf::ForOp forOp : allOps) {
auto r = forOp->walk([](scf::ForOp) { return WalkResult::interrupt(); });
if (!r.wasInterrupted())
auto searchResult =
forOp->walk([](scf::ForOp) { return WalkResult::interrupt(); });
if (searchResult.wasInterrupted())
leafOps.push_back(forOp);
}
return leafOps;
Expand Down

0 comments on commit 661a12b

Please sign in to comment.