Add support for torch inductor modulo op pattern (#68)
This PR addresses several issues:

1. Support the modulo pattern generated by torch inductor
We now support `tl.arange(0, size)[:, None] % mod` (expanding the shape before applying the modulo). Fixes #14 and #48. See the sketch after this list for the pattern.

2. Add more modulo tests running on the CPU backend.
While doing so, I found that our current usage of `memref.reinterpret_cast` to support modulo ops in a loop is incorrect. Previously, we inserted two "no-op" `memref.reinterpret_cast` ops for the two blocks of the mod pointers so that `LoadConverter` could determine the sizes of the blocks to copy into the local buffers. However, when lowering all the way to LLVM, this meant we were resetting the offset of the blocks being yielded in each loop iteration. To solve this, I have replaced the casts with proper `memref.dim` ops to get the correct sizes.

3. Fix type mismatches between individual modulo blocks and loop yield values
Previously, the types of the individual modulo blocks could have static strides, while their corresponding loop yield values have dynamic strides, causing a type mismatch. I have instead made the strides dynamic to begin with.

4. Support lowering to CPU for more cases
Lowering to memref can produce additional affine ops after the affine lowering passes have already run in the current pass ordering. I have added two more passes to the pass list to fix this issue.

5. Add a softmax tutorial test for the CPU backend
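
For context, here is a minimal Triton sketch of the inductor-style pattern from item 1. The kernel name, arguments, and launch are hypothetical and only illustrate the expand-shape-then-modulo addressing this PR adds support for; the example run assumes a CUDA-capable device.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def wrap_rows_kernel(in_ptr, out_ptr, stride_m,
                     M: tl.constexpr, BLOCK: tl.constexpr):
    # torch inductor expands the shape first, then applies the modulo:
    # `rows` has shape (BLOCK, 1) and wraps around every M rows.
    rows = tl.arange(0, BLOCK)[:, None] % M
    cols = tl.arange(0, BLOCK)[None, :]
    # 2D block of pointers whose row index wraps around the source tensor.
    vals = tl.load(in_ptr + rows * stride_m + cols)
    tl.store(out_ptr + tl.arange(0, BLOCK)[:, None] * BLOCK + cols, vals)


# Hypothetical launch: read a (12, 16) tensor into a (16, 16) buffer,
# repeating rows once the wraparound point M = 12 is reached.
x = torch.arange(12 * 16, device="cuda", dtype=torch.float32).reshape(12, 16)
out = torch.empty(16, 16, device="cuda", dtype=torch.float32)
wrap_rows_kernel[(1,)](x, out, x.stride(0), M=12, BLOCK=16)
```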
nhat-nguyen authored Dec 1, 2023
1 parent f90b031 commit bd3ee28
Showing 12 changed files with 518 additions and 131 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
*.pyc
6 changes: 2 additions & 4 deletions include/triton-shared/Analysis/MaskAnalysis.h
@@ -67,13 +67,11 @@ struct MaskState {
ConversionPatternRewriter &rewriter) const;

std::pair<memref::SubViewOp, memref::SubViewOp>
getSideBySideSubviews(memref::ReinterpretCastOp chunk1,
memref::ReinterpretCastOp chunk2, const Location loc,
getSideBySideSubviews(Value block1, Value block2, const Location loc,
ConversionPatternRewriter &rewriter) const;

std::pair<memref::SubViewOp, memref::SubViewOp>
getStackedSubviews(memref::ReinterpretCastOp chunk1,
memref::ReinterpretCastOp chunk2, const Location loc,
getStackedSubviews(Value block1, Value block2, const Location loc,
ConversionPatternRewriter &rewriter) const;

private:
14 changes: 10 additions & 4 deletions include/triton-shared/Analysis/PtrAnalysis.h
@@ -24,9 +24,14 @@ namespace triton {

struct ModuloState {
Value size;
OpFoldResult offset;
ModuloState() {}
ModuloState(Value size, OpFoldResult offset) : size{size}, offset{offset} {}

// offset is used to determine the wraparound point for patterns like:
// offset + (tl.arange(0, 256) % 12)
// The current code assumes that the modulo operator always runs last, e.g:
// (offset + tl.arange(0, 256)) % 12
// This is not used at the moment as there haven't been enough use cases and
// the implementation is quite complex.
// OpFoldResult offset;

static constexpr char const *WraparoundAttr = "ptr.wraparound_type";
static constexpr char const *WraparoundStacked = "stacked";
@@ -61,7 +66,8 @@ class PtrState {
bool hasModulo() const;

MemRefType getResultMemrefType(MLIRContext *context, int64_t offset,
ArrayRef<int64_t> resultShape) const;
ArrayRef<int64_t> resultShape,
bool useDynamicStrides = false) const;

// Process addition of two PtrStates.
void addState(const PtrState &lhsState, const PtrState &rhsState,
35 changes: 14 additions & 21 deletions lib/Analysis/MaskAnalysis.cpp
@@ -135,49 +135,42 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b,
// + colView2 = colFull - colView1
// + rowView1 = rowView2 = row = rowFull
std::pair<memref::SubViewOp, memref::SubViewOp>
MaskState::getSideBySideSubviews(memref::ReinterpretCastOp block1,
memref::ReinterpretCastOp block2,
const Location loc,
MaskState::getSideBySideSubviews(Value block1, Value block2, const Location loc,
ConversionPatternRewriter &rewriter) const {

assert(block1.getResultRank() == 2 && block2.getResultRank() == 2 &&
getRank() == 2);

OpFoldResult subviewRowFull = dims[0];
OpFoldResult subviewColFull = dims[1];
OpFoldResult col1 = block1.getMixedSizes()[1];
OpFoldResult col1 =
rewriter.create<memref::DimOp>(loc, block1, 1).getResult();
OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, rewriter);
OpFoldResult subviewCol2 =
subOFRs(subviewColFull, subviewCol1, loc, rewriter);

SmallVector<OpFoldResult> offsets(getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides = block1.getMixedStrides();
auto sv1 = createSubview(block1.getResult(), loc, rewriter, offsets,
SmallVector<OpFoldResult> strides(getRank(), rewriter.getIndexAttr(1));
auto sv1 = createSubview(block1, loc, rewriter, offsets,
{subviewRowFull, subviewCol1}, strides);
auto sv2 = createSubview(block2.getResult(), loc, rewriter, offsets,
auto sv2 = createSubview(block2, loc, rewriter, offsets,
{subviewRowFull, subviewCol2}, strides);

return {sv1, sv2};
}

std::pair<memref::SubViewOp, memref::SubViewOp> MaskState::getStackedSubviews(
memref::ReinterpretCastOp block1, memref::ReinterpretCastOp block2,
const Location loc, ConversionPatternRewriter &rewriter) const {
assert(block1.getResultRank() == 2 && block2.getResultRank() == 2 &&
getRank() == 2);

std::pair<memref::SubViewOp, memref::SubViewOp>
MaskState::getStackedSubviews(Value block1, Value block2, const Location loc,
ConversionPatternRewriter &rewriter) const {
OpFoldResult subviewRowFull = dims[0];
OpFoldResult subviewColFull = dims[1];
OpFoldResult row1 = block1.getMixedSizes()[0];
OpFoldResult row1 =
rewriter.create<memref::DimOp>(loc, block1, 0).getResult();
OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, rewriter);
OpFoldResult subviewRow2 =
subOFRs(subviewRowFull, subviewRow1, loc, rewriter);

SmallVector<OpFoldResult> offsets(getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides = block1.getMixedStrides();
auto sv1 = createSubview(block1.getResult(), loc, rewriter, offsets,
SmallVector<OpFoldResult> strides(getRank(), rewriter.getIndexAttr(1));
auto sv1 = createSubview(block1, loc, rewriter, offsets,
{subviewRow1, subviewColFull}, strides);
auto sv2 = createSubview(block2.getResult(), loc, rewriter, offsets,
auto sv2 = createSubview(block2, loc, rewriter, offsets,
{subviewRow2, subviewColFull}, strides);
return {sv1, sv2};
}
73 changes: 52 additions & 21 deletions lib/Analysis/PtrAnalysis.cpp
@@ -29,11 +29,16 @@ static void assertValidUnrealizedCast(UnrealizedConversionCastOp op) {
}

MemRefType PtrState::getResultMemrefType(MLIRContext *context, int64_t offset,
ArrayRef<int64_t> resultShape) const {
ArrayRef<int64_t> resultShape,
bool useDynamicStrides) const {

SmallVector<int64_t> staticStrides;
SmallVector<Value> dynamicStrides;
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
if (useDynamicStrides) {
staticStrides.append(strides.size(), ShapedType::kDynamic);
} else {
SmallVector<Value> dynamicStrides;
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
}

auto elementType = source.getType().cast<BaseMemRefType>().getElementType();
auto layout =
@@ -196,7 +201,8 @@ PtrState::createStackedCastOps(ArrayRef<int64_t> resultShape,
// the same as the original row. The last chunk
// may be smaller due to wrapping around.
resultShape[1], // Col stays the same.
});
},
true /*useDynamicStrides*/);

Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter);
Value colSize = ofrToIndexValue(sizes[1], loc, rewriter);
@@ -287,30 +293,33 @@ PtrState::createSideBySideCastOps(ArrayRef<int64_t> resultShape,
ShapedType::kDynamic // Column is dynamic, in most cases, this should
// be the same as the original column. The last
// chunk may be smaller due to wrapping around.
});
},
true /*useDynamicStrides*/);

Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter);
Value colSize = ofrToIndexValue(sizes[1], loc, rewriter);

Value modN = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
modulos[1]->size);

SmallVector<Value> strideVals = ofrsToIndexValues(strides, loc, rewriter);

Value x = rewriter.create<arith::RemSIOp>(loc, targetOffset, modN);
Value y = rewriter.create<arith::SubIOp>(loc, targetOffset, x);

SmallVector<Value> strideVals = ofrsToIndexValues(strides, loc, rewriter);

// First chunk
Value nextOffset = rewriter.create<arith::AddIOp>(loc, x, colSize);
Value clampedOffset = rewriter.create<arith::MinSIOp>(loc, nextOffset, modN);
Value d1 = rewriter.create<arith::SubIOp>(loc, clampedOffset, x);
SmallVector<Value> sizes1{rowSize, d1};

auto cast1 = rewriter.create<memref::ReinterpretCastOp>(
loc, resultType, source, targetOffset, sizes1, strideVals);

// Second chunk
Value d2 = rewriter.create<arith::SubIOp>(loc, colSize, d1);
SmallVector<Value> sizes2{rowSize, d2};

auto cast2 = rewriter.create<memref::ReinterpretCastOp>(
loc, resultType, source, y, sizes2, strideVals);

@@ -372,15 +381,46 @@ void PtrAnalysis::visitOperandRem(
ConversionPatternRewriter &rewriter,
const llvm::SmallDenseMap<Value, PtrState> &knownPtrs) {
assert(state.isEmpty());
visitOperand(remOp.getLhs(), state, loc, rewriter, knownPtrs);
assert(state.getRank() == 1 && !state.modulos.back().has_value() &&
"No support for multiple modulos within an expression");

PtrState rhsState;
visitOperand(remOp.getRhs(), rhsState, loc, rewriter, knownPtrs);
assert(rhsState.scalar);

state.modulos.back() = ModuloState(rhsState.scalar, rewriter.getIndexAttr(0));
visitOperand(remOp.getLhs(), state, loc, rewriter, knownPtrs);

// If there are multiple modulo ops on an expression (e.g.: (a % b) % c), we
// would have already populated the modulo states after visiting the lhs.
// Assert that all the modulo states are empty.
assert(llvm::all_of(state.modulos,
[](auto modState) { return !modState.has_value(); }) &&
"No support for multiple modulo within an expression");

if (state.getRank() == 1) {
// Apply the modulo before expanding shape, the common pattern is
// offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
// a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] *
// stride_ak)
state.modulos.back() = ModuloState{rhsState.scalar};
} else if (state.getRank() == 2) {
// torch inductor expands the tensor shape before applying the modulo.
//
// We only support either:
// - (tl.arange(0, end)[:, None] % mod), or
// - (tl.arange(0, end)[None, :] % mod)
//
// In both cases, we apply the modulo to the non-singleton dimension.
auto shape = cast<TensorType>(remOp.getResult().getType()).getShape();
if (shape[0] == 1) {
state.modulos[1] = ModuloState{rhsState.scalar};
} else if (shape[1] == 1) {
state.modulos[0] = ModuloState{rhsState.scalar};
} else {
assert(false && "Taking modulo on a 2D tensor with no singleton "
"dimension not supported");
}
} else {
assert(false && "Unsupported modulo pattern");
}
}

void PtrAnalysis::visitOperandMakeRange(
@@ -1223,16 +1263,7 @@ void PtrAnalysis::rewriteForOp(
for (auto &[unrealizedCastOp, chunkData, state] : moduloStates) {
SmallVector<Value> newReinterpretCasts;
for (auto &chunk : chunkData) {
auto initReintCast =
chunk.reinterpretCast
.getDefiningOp<memref::ReinterpretCastOp>();

auto newReintCast = b.create<memref::ReinterpretCastOp>(
loc, initReintCast.getResult().getType(),
args[chunk.initArgIndex], zero, initReintCast.getSizes(),
initReintCast.getStrides());

newReinterpretCasts.push_back(newReintCast);
newReinterpretCasts.push_back(args[chunk.initArgIndex]);
}

auto combinedCast = b.create<UnrealizedConversionCastOp>(
110 changes: 62 additions & 48 deletions lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
@@ -489,62 +489,76 @@ struct LoadConverter : public OpConversionPattern<triton::LoadOp> {
private:
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;

template <typename SourceOpTy>
void createSideBySideCopies(SourceOpTy block1, SourceOpTy block2, Value dst,
void createSideBySideCopies(Value block1, Value block2, Value dst,
Location loc,
ConversionPatternRewriter &rewriter) const {
static_assert((std::is_same<SourceOpTy, memref::ReinterpretCastOp>() ||
std::is_same<SourceOpTy, memref::SubViewOp>()) &&
"Expect source of split pointers to come from either "
"reinterpret_cast or subview ops");

auto zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));

auto block1Dst = rewriter.create<memref::SubViewOp>(
loc, dst, /* offsets */ ValueRange{zero, zero},
ofrsToIndexValues(block1.getMixedSizes(), loc, rewriter),
ofrsToIndexValues(block1.getMixedStrides(), loc, rewriter));

auto block2Dst = rewriter.create<memref::SubViewOp>(
loc, dst,
/* offsets */
ValueRange{zero,
ofrToIndexValue(block1.getMixedSizes()[1], loc, rewriter)},
ofrsToIndexValues(block2.getMixedSizes(), loc, rewriter),
ofrsToIndexValues(block2.getMixedStrides(), loc, rewriter));

rewriter.create<memref::CopyOp>(loc, block1.getResult(), block1Dst);
rewriter.create<memref::CopyOp>(loc, block2.getResult(), block2Dst);
auto one =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

Value block1Row = rewriter.create<memref::DimOp>(loc, block1, 0);
Value block1Col = rewriter.create<memref::DimOp>(loc, block1, 1);

Value block2Row = rewriter.create<memref::DimOp>(loc, block2, 0);
Value block2Col = rewriter.create<memref::DimOp>(loc, block2, 1);

auto block1Dst =
rewriter.create<memref::SubViewOp>(loc, dst, /* offsets */
ValueRange{zero, zero},
/* sizes */
ValueRange{block1Row, block1Col},
/* strides */
ValueRange{one, one});

auto block2Dst =
rewriter.create<memref::SubViewOp>(loc, dst,
/* offsets */
ValueRange{zero, block1Col},
/* sizes */
ValueRange{block2Row, block2Col},
/* strides */
ValueRange{one, one});

rewriter.create<memref::CopyOp>(loc, block1, block1Dst);
rewriter.create<memref::CopyOp>(loc, block2, block2Dst);
}

template <typename SourceOpTy>
void createStackedCopies(SourceOpTy block1, SourceOpTy block2, Value dst,
Location loc,
void createStackedCopies(Value block1, Value block2, Value dst, Location loc,
ConversionPatternRewriter &rewriter) const {
static_assert((std::is_same<SourceOpTy, memref::ReinterpretCastOp>() ||
std::is_same<SourceOpTy, memref::SubViewOp>()) &&
"Expect source of split pointers to come from either "
"reinterpret_cast or subview ops");

auto zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));

auto block1Dst = rewriter.create<memref::SubViewOp>(
loc, dst, /* offsets */ ValueRange{zero, zero},
ofrsToIndexValues(block1.getMixedSizes(), loc, rewriter),
ofrsToIndexValues(block1.getMixedStrides(), loc, rewriter));

auto block2Dst = rewriter.create<memref::SubViewOp>(
loc, dst,
/* offsets */
ValueRange{ofrToIndexValue(block1.getMixedSizes()[0], loc, rewriter),
zero},
ofrsToIndexValues(block2.getMixedSizes(), loc, rewriter),
ofrsToIndexValues(block2.getMixedStrides(), loc, rewriter));

rewriter.create<memref::CopyOp>(loc, block1.getResult(), block1Dst);
rewriter.create<memref::CopyOp>(loc, block2.getResult(), block2Dst);
auto one =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

Value block1Row = rewriter.create<memref::DimOp>(loc, block1, 0);
Value block1Col = rewriter.create<memref::DimOp>(loc, block1, 1);

Value block2Row = rewriter.create<memref::DimOp>(loc, block2, 0);
Value block2Col = rewriter.create<memref::DimOp>(loc, block2, 1);

auto block1Dst =
rewriter.create<memref::SubViewOp>(loc, dst, /* offsets */
ValueRange{zero, zero},
/* sizes */
ValueRange{block1Row, block1Col},
/* strides */
ValueRange{one, one});

auto block2Dst =
rewriter.create<memref::SubViewOp>(loc, dst,
/* offsets */
ValueRange{block1Row, zero},
/* sizes */
ValueRange{block2Row, block2Col},
/* strides */
ValueRange{one, one});

rewriter.create<memref::CopyOp>(loc, block1, block1Dst);
rewriter.create<memref::CopyOp>(loc, block2, block2Dst);
}

public:
@@ -591,8 +605,8 @@ struct LoadConverter : public OpConversionPattern<triton::LoadOp> {
ModuloState::WraparoundAttr)) {

auto memrefs = unrealizedCast.getOperands();
auto block1 = memrefs[0].getDefiningOp<memref::ReinterpretCastOp>();
auto block2 = memrefs[1].getDefiningOp<memref::ReinterpretCastOp>();
auto block1 = memrefs[0];
auto block2 = memrefs[1];

if (wrapType.getValue().equals(ModuloState::WraparoundSideBySide)) {
createSideBySideCopies(block1, block2, alloc, loc, rewriter);
@@ -671,8 +685,8 @@ struct LoadConverter : public OpConversionPattern<triton::LoadOp> {
ModuloState::WraparoundAttr)) {

auto memrefs = unrealizedCast.getOperands();
auto block1 = memrefs[0].getDefiningOp<memref::ReinterpretCastOp>();
auto block2 = memrefs[1].getDefiningOp<memref::ReinterpretCastOp>();
auto block1 = memrefs[0];
auto block2 = memrefs[1];

if (wrapType.getValue().equals(ModuloState::WraparoundSideBySide)) {
auto [subview1, subview2] =