Skip to content

Commit

Permalink
Set PtrState scalar to 0 for pointers from kernel arguments (#189)
Browse files Browse the repository at this point in the history
When rewriting `tts.get_structured_state` operating on scalar pointers,
we need to use the offset from `state.scalar`. Pointers directly from
the kernel arguments unfortunately is missing the `scalar` field,
causing a crash when creating offset from this `scalar` field.
  • Loading branch information
nhat-nguyen authored Nov 6, 2024
1 parent d8c8f29 commit 3fe82cb
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions lib/AnalysisStructured/PtrAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1055,13 +1055,21 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) {
ptrMap.contains(tritonValue) ? ptrMap.lookup(tritonValue) : tritonValue;

SmallVector<Value> replacements{remappedValue};
OpBuilder builder(op);

if (state.getRank() == 0) {
// For scalar pointers, the scalar contains the offset and is the only
// relevant state that could be updated by the loop.
replacements.push_back(state.scalar);
if (state.scalar) {
replacements.push_back(state.scalar);
} else {
// This operand is a pointer directly from the kernel arguments.
// Use offset 0.
assert(!tritonValue.getDefiningOp());
replacements.push_back(builder.create<arith::ConstantOp>(
op.getLoc(), builder.getIndexAttr(0)));
}
} else {
OpBuilder builder(op);
for (auto [j, s] : llvm::enumerate(state.offsets)) {
auto sIntAttr = getIntAttr(s);
if (sIntAttr) {
Expand Down

0 comments on commit 3fe82cb

Please sign in to comment.