Skip to content

Commit

Permalink
Remove unused code in PtrAnalysis
Browse files Browse the repository at this point in the history
  • Loading branch information
nhat-nguyen authored Nov 20, 2023
1 parent ea77225 commit 64bb6c5
Showing 1 changed file with 22 additions and 63 deletions.
85 changes: 22 additions & 63 deletions lib/Analysis/PtrAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ namespace mlir {

namespace triton {

static void assertValidUnrealizedCast(UnrealizedConversionCastOp op) {
assert(op && op->hasAttr(ModuloState::WraparoundAttr) &&
op.getInputs().size() == 3 &&
op.getInputs()[0].getDefiningOp<memref::ReinterpretCastOp>() &&
op.getInputs()[1].getDefiningOp<memref::ReinterpretCastOp>() &&
op.getInputs()[2].getDefiningOp<triton::AddPtrOp>());
}

MemRefType PtrState::getResultMemrefType(MLIRContext *context, int64_t offset,
ArrayRef<int64_t> resultShape) const {

Expand Down Expand Up @@ -371,7 +379,6 @@ void PtrAnalysis::visitOperandRem(
PtrState rhsState;
visitOperand(remOp.getRhs(), rhsState, loc, rewriter, knownPtrs);
assert(rhsState.scalar);
rhsState.scalar.dump();

state.modulos.back() = ModuloState(rhsState.scalar, rewriter.getIndexAttr(0));
}
Expand Down Expand Up @@ -860,37 +867,18 @@ void PtrAnalysis::rewriteYieldOp(
for (auto [i, v] : llvm::enumerate(adaptor.getOperands())) {
if (auto mappedV = rewriter.getRemappedValue(v)) {
// If this value is a tensor of pointers produced by AddPtrOp,
// TritonTypeConverter should have converted to MemRefType without
// layout information. Since it doesn't match with the MemRefType
// that we produced in rewriteAddptrOp (which is in canonical form
// with layout information), an unrealized_conversion_cast should
// have been added. We need to trace it back through this
// unrealized_conversion_cast to get the original reinterpret_cast.
// Also see comments in TritonTypeConverter::addConversion.
//
// For TritonToLinalg, we do not use any TypeConverters, hence we
// can access the reinterpret_cast directly.
// we should have already converted to a ReinterpretCastOp without
// layout information for the normal cases, or to an
// UnrealizedConversionCastOp for the split pointer case.
if (v.getDefiningOp<triton::AddPtrOp>() ||
v.getDefiningOp<triton::AdvanceOp>() ||
v.getDefiningOp<triton::MakeTensorPtrOp>()) {
if (auto castOp = mappedV.getDefiningOp<UnrealizedConversionCastOp>()) {
assertValidUnrealizedCast(castOp);
auto castInputs = castOp.getInputs();

assert((castInputs.size() == 1 ||
castOp->hasAttr(ModuloState::WraparoundAttr)) &&
"only expect 1:1 mapping for "
"unrealized_conversion_cast that "
"were "
"automatically inserted during legalizing");

if (castOp->hasAttr(ModuloState::WraparoundAttr)) {
v = castOp.getResult(0);
operands[i] = castInputs[0];
moduloSecondChunks.push_back(castInputs[1]);
} else {
v = castInputs[0];
}

v = castOp.getResult(0);
operands[i] = castInputs[0];
moduloSecondChunks.push_back(castInputs[1]);
} else if (auto castOp =
mappedV.getDefiningOp<memref::ReinterpretCastOp>()) {
v = castOp;
Expand Down Expand Up @@ -937,6 +925,7 @@ void PtrAnalysis::rewriteYieldOp(
visitOperandReintCast(reintCastOp, state, op.getLoc(), rewriter,
knownPtrs);
} else if (unrealizedCastOp) {
assertValidUnrealizedCast(unrealizedCastOp);
visitOperandUnrealizedCast(unrealizedCastOp, state, op.getLoc(), rewriter,
knownPtrs);
} else {
Expand Down Expand Up @@ -998,11 +987,7 @@ void PtrAnalysis::visitOperandUnrealizedCast(
UnrealizedConversionCastOp op, PtrState &state, const Location loc,
ConversionPatternRewriter &rewriter,
const llvm::SmallDenseMap<Value, PtrState> &knownPtrs) {
assert(op->hasAttr(ModuloState::WraparoundAttr) &&
op.getInputs().size() == 3 &&
op.getInputs()[0].getDefiningOp<memref::ReinterpretCastOp>() &&
op.getInputs()[1].getDefiningOp<memref::ReinterpretCastOp>() &&
op.getInputs()[2].getDefiningOp<triton::AddPtrOp>());
assertValidUnrealizedCast(op);

auto origPtr = op.getInputs()[2];
if (knownPtrs.contains(origPtr)) {
Expand Down Expand Up @@ -1051,23 +1036,6 @@ void PtrAnalysis::rewriteForOp(
// Create a new list of init args
for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) {
auto mappedV = rewriter.getRemappedValue(arg);

// Trace back the original value. See comments in rewriteYieldOp.
// This block is unreachable for TritonToLinalg because we don't use
// TypeConverters.
if (mappedV && mappedV.getDefiningOp<UnrealizedConversionCastOp>()) {
auto castOp = mappedV.getDefiningOp<UnrealizedConversionCastOp>();
if (!castOp->hasAttr(ModuloState::WraparoundAttr)) {
assert(castOp && "expected unrealized_conversion_cast");
auto castInputs = castOp.getInputs();
assert(castInputs.size() == 1 &&
"only expect 1:1 mapping for unrealized_conversion_cast "
"that "
"were automatically inserted during legalizing");
mappedV = castInputs[0];
}
}

memref::ReinterpretCastOp reintCastOp;
UnrealizedConversionCastOp unrealizedCastOp;

Expand All @@ -1087,10 +1055,9 @@ void PtrAnalysis::rewriteForOp(
newInitArgs.push_back(mappedV);
} else if (auto op =
mappedV.getDefiningOp<UnrealizedConversionCastOp>()) {
assert(op->hasAttr(ModuloState::WraparoundAttr));
assertValidUnrealizedCast(op);
unrealizedCastOp = op;
auto inputs = unrealizedCastOp.getInputs();
assert(inputs.size() == 3);

SmallVector<ModuloChunkInitArg> initArgData{
ModuloChunkInitArg{inputs[0], i},
Expand Down Expand Up @@ -1341,19 +1308,11 @@ Value PtrAnalysis::getScalarMemRef(Value ptr, Value memRef, const Location loc,
assert(ptr.getType().cast<triton::PointerType>() &&
"expected scalar pointer");

// If pointer is generated by tt.addptr, TypeConverter will have inserted an
// unrealized conversion cast for ptr to cast its type from tt.ptr to unranked
// memref. Input of this cast is the actual source memref.
//
// For TritonToLinalg, we can access the reinterpret_cast directly due to no
// usages of TypeConverters.
// If the pointer is generated by tt.addptr, we will have already inserted an
// ReinterpretCastOp to cast its type from tt.ptr to unranked memref. Return
// the result.
if (ptr.getDefiningOp<triton::AddPtrOp>()) {
if (auto uCast = memRef.getDefiningOp<UnrealizedConversionCastOp>()) {
assert(uCast && "expected unrealized conversion inserted by type "
"converter not found");
return uCast.getInputs()[0];
} else if (auto castOp =
memRef.getDefiningOp<memref::ReinterpretCastOp>()) {
if (auto castOp = memRef.getDefiningOp<memref::ReinterpretCastOp>()) {
return castOp.getResult();
} else {
llvm_unreachable("pointer value is defined by an unexpected op");
Expand Down

0 comments on commit 64bb6c5

Please sign in to comment.