Commit 20ad94e
Revert "fix issues after LLVM update (#22)"
This reverts commit 450e6be.
nhat-nguyen committed Jan 4, 2024
1 parent c2fb18d
Showing 5 changed files with 11 additions and 16 deletions.
2 changes: 0 additions & 2 deletions include/triton-shared/Analysis/UseAnalysis.h
@@ -85,8 +85,6 @@ class UseAnalysis : public dataflow::SparseBackwardDataFlowAnalysis<UseInfo> {

   void visitBranchOperand(OpOperand &operand) override { return; }

-  void visitCallOperand(OpOperand &operand) override { return; }
-
   void setToExitState(UseInfo *lattice) override {
     lattice->type = UseType::Undefined;
   }
4 changes: 2 additions & 2 deletions lib/Analysis/PtrAnalysis.cpp
@@ -1231,7 +1231,7 @@ void PtrAnalysis::rewriteForOp(
   mapping.map(op.getInitArgs(), newInitArgs);
   mapping.map(op.getRegionIterArgs(), args);

-  for (auto &bodyOp : op.getRegion().getOps()) {
+  for (auto &bodyOp : op.getLoopBody().getOps()) {
     b.clone(bodyOp, mapping);
   }

@@ -1309,7 +1309,7 @@ void PtrAnalysis::rewriteForOp(

   // Update the loop body. Manually invoke the rewrite logic on addptr and yield
   // in the loop body, so we can take advantage of the states we built up
-  for (auto &bodyOp : newOp.getRegion().getOps()) {
+  for (auto &bodyOp : newOp.getLoopBody().getOps()) {
     if (auto addptrOp = dyn_cast<triton::AddPtrOp>(bodyOp)) {
       rewriteAddptrOp(addptrOp, rewriter, knownPtrs);
     } else if (auto advanceOp = dyn_cast<triton::AdvanceOp>(bodyOp)) {
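
Context for this pair of hunks: commit 450e6be appears to track newer LLVM dropping the getLoopBody() helper on scf::ForOp in favor of reaching the body region directly, so the revert pins the code back to the older spelling. Below is a minimal sketch of the cloning idiom used here, assuming the pre-update API and an IRMapping prepared by the caller; cloneLoopBody is a hypothetical name, not code from this repo.

// Minimal sketch (assumed context, not part of this commit): clone the body
// of an scf.for loop with remapped values, using the accessor this revert
// restores. On newer LLVM the same region is reached via op.getRegion().
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"

static void cloneLoopBody(mlir::scf::ForOp op, mlir::OpBuilder &b,
                          mlir::IRMapping &mapping) {
  // Each cloned op has its operands remapped through `mapping`, so uses of
  // the old iter args resolve to the new loop's block arguments.
  for (mlir::Operation &bodyOp : op.getLoopBody().getOps())
    b.clone(bodyOp, mapping);
}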
17 changes: 7 additions & 10 deletions lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
@@ -838,8 +838,8 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
   }

   bool isReductionOpSupported(Operation *redOp) const {
-    return isa<arith::AddFOp, arith::AddIOp, arith::MaximumFOp, arith::MinSIOp,
-               arith::MinUIOp, arith::MaxSIOp, arith::MaxUIOp>(redOp);
+    return isa<arith::AddFOp, arith::MaximumFOp, arith::MinSIOp, arith::MinUIOp,
+               arith::MaxSIOp, arith::MaxUIOp>(redOp);
   }

   arith::ConstantOp getRedBaseConstOp(ConversionPatternRewriter &rewriter,
@@ -852,9 +852,6 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
        .Case([&](arith::AddFOp) {
          return rewriter.getFloatAttr(constantType, 0.f);
        })
-        .Case([&](arith::AddIOp) {
-          return rewriter.getIntegerAttr(constantType, 0);
-        })
        .Case([&](arith::MaximumFOp) {
          return rewriter.getFloatAttr(
              constantType, -std::numeric_limits<float>::infinity());
@@ -902,8 +899,8 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
          }
          return b.create<arith::AddFOp>(loc, lhs, rhs);
        })
-        .Case<arith::AddIOp, arith::MaximumFOp, arith::MinSIOp, arith::MinUIOp,
-              arith::MaxSIOp, arith::MaxUIOp>([&](auto redOp) {
+        .Case<arith::MaxFOp, arith::MinSIOp, arith::MinUIOp, arith::MaxSIOp,
+              arith::MaxUIOp>([&](auto redOp) {
          return b.create<decltype(redOp)>(loc, lhs, rhs);
        })
        .Default([](Operation *op) {
@@ -1158,12 +1155,12 @@ struct MinMaxConverter : public OpRewritePattern<CmpOp> {
                       arith::CmpFPredicate pred) const {
     switch (pred) {
     case arith::CmpFPredicate::OGT:
-      rewriter.replaceOpWithNewOp<arith::MaximumFOp>(selectOp, cmpOp.getLhs(),
-                                                     cmpOp.getRhs());
+      rewriter.replaceOpWithNewOp<arith::MaxFOp>(selectOp, cmpOp.getLhs(),
+                                                 cmpOp.getRhs());
       break;
     case arith::CmpFPredicate::OLT:
-      rewriter.replaceOpWithNewOp<arith::MinimumFOp>(selectOp, cmpOp.getLhs(),
-                                                     cmpOp.getRhs());
+      rewriter.replaceOpWithNewOp<arith::MinFOp>(selectOp, cmpOp.getLhs(),
+                                                 cmpOp.getRhs());
       break;
     default:
       llvm_unreachable("Unhandled predicate");
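
For the last hunk above: MinMaxConverter folds the compare-then-select spelling of a float min/max into a single arith op. Here is a condensed sketch under the post-revert op names, with hypothetical plumbing; foldCmpSelect is illustrative only, and the real pattern also verifies that the select actually consumes the compare and shares its operands.

// Condensed sketch (hypothetical helper, not the commit's full pattern):
// rewrites  %c = arith.cmpf ogt, %a, %b : f32
//           %m = arith.select %c, %a, %b : f32
// into      %m = arith.maxf %a, %b : f32
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

static void foldCmpSelect(mlir::PatternRewriter &rewriter,
                          mlir::arith::CmpFOp cmpOp,
                          mlir::arith::SelectOp selectOp) {
  switch (cmpOp.getPredicate()) {
  case mlir::arith::CmpFPredicate::OGT:
    rewriter.replaceOpWithNewOp<mlir::arith::MaxFOp>(selectOp, cmpOp.getLhs(),
                                                     cmpOp.getRhs());
    break;
  case mlir::arith::CmpFPredicate::OLT:
    rewriter.replaceOpWithNewOp<mlir::arith::MinFOp>(selectOp, cmpOp.getLhs(),
                                                     cmpOp.getRhs());
    break;
  default:
    break; // other predicates are left for the pattern to reject
  }
}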
2 changes: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ module {
 // CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_]] into [[VAR_6_]][] : tensor<f32>
 // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_5_]] : tensor<128xf32>) outs([[VAR_inserted_]] : tensor<f32>) dimensions = [0]
 // CHECK: ([[in_1:%.+]]: f32, [[init_1:%.+]]: f32) {
-// CHECK: [[VAR_19_:%.+]] = arith.maximumf [[in_1]], [[init_1]] : f32
+// CHECK: [[VAR_19_:%.+]] = arith.maxf [[in_1]], [[init_1]] : f32
 // CHECK: linalg.yield [[VAR_19_]] : f32
 // CHECK: }
 // CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor<f32>
2 changes: 1 addition & 1 deletion test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir
@@ -50,7 +50,7 @@ module {
 // CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : tensor<256x16xbf16>) -> tensor<256x16xbf16>
 // CHECK: %[[VAL_12:.*]] = linalg.reduce ins(%[[VAL_9]] : tensor<32x256x16xbf16>) outs(%[[VAL_11]] : tensor<256x16xbf16>) dimensions = [0]
 // CHECK: (%[[VAL_13:.*]]: bf16, %[[VAL_14:.*]]: bf16) {
-// CHECK: %[[VAL_15:.*]] = arith.maximumf %[[VAL_13]], %[[VAL_14]] : bf16
+// CHECK: %[[VAL_15:.*]] = arith.maxf %[[VAL_13]], %[[VAL_14]] : bf16
 // CHECK: linalg.yield %[[VAL_15]] : bf16
 // CHECK: }
 // CHECK: memref.tensor_store %[[VAL_12]], %[[VAL_1]] : memref<256x16xbf16>
