From bd3ee28c242eeae1fede4acc3cc09721c990e9d1 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Fri, 1 Dec 2023 22:51:26 +0700 Subject: [PATCH] Add support for torch inductor modulo op pattern (#68) This PR addresses several issues: 1. Support modulo pattern generated by torch inductor We now support `tl.arange(0, size)[:, None] % mod` -- expanding the shape before applying the modulo). Fixes #14 #48. 2. Add more modulo tests running on the CPU backend, I found out that our current usages of `memref.reinterpret_cast` to support modulo ops in a loop are incorrect. Previously we insert two "no-op" `memref.reinterpret_cast` for the two blocks of the mod pointers so that `LoadConveter` can determine the sizes of the blocks to copy to the local buffers. However, when lowering all the way to llvm, doing this meant that we are resetting the offset of the blocks being yielded in each loop interation. To solve this, I have replaced the casts with the proper `memref.dim_op` to get the correct sizes. 3. Fix individual modulo block's type can sometimes mismatch in a loop Previously, the types for each individual modulo block can have static strides. During a loop, their corresponding loop's yield values have dynamic strides, causing type mismatch. I have instead make the strides always dynamic to begin with. 4. Support lowering to CPU for more cases Lowering to memref can produces more affine ops which we would have already run in the current pass ordering. I have added two additional passes in the pass list to fix this issue. 5. Add softmax tutorial test for CPU backend --- .gitignore | 1 + include/triton-shared/Analysis/MaskAnalysis.h | 6 +- include/triton-shared/Analysis/PtrAnalysis.h | 14 +- lib/Analysis/MaskAnalysis.cpp | 35 +-- lib/Analysis/PtrAnalysis.cpp | 73 +++-- .../TritonToLinalg/TritonToLinalg.cpp | 110 +++---- python/__init__.py | 6 + python/examples/test_modulo.py | 272 ++++++++++++++++++ python/examples/test_reduce.py | 2 + python/examples/test_softmax.py | 65 +++++ .../wraparound_side_by_side.mlir | 30 +- .../TritonToLinalg/wraparound_stacked.mlir | 35 ++- 12 files changed, 518 insertions(+), 131 deletions(-) create mode 100644 .gitignore create mode 100644 python/examples/test_modulo.py create mode 100644 python/examples/test_softmax.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..0d20b648 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.pyc diff --git a/include/triton-shared/Analysis/MaskAnalysis.h b/include/triton-shared/Analysis/MaskAnalysis.h index 65170073..531598e6 100644 --- a/include/triton-shared/Analysis/MaskAnalysis.h +++ b/include/triton-shared/Analysis/MaskAnalysis.h @@ -67,13 +67,11 @@ struct MaskState { ConversionPatternRewriter &rewriter) const; std::pair - getSideBySideSubviews(memref::ReinterpretCastOp chunk1, - memref::ReinterpretCastOp chunk2, const Location loc, + getSideBySideSubviews(Value block1, Value block2, const Location loc, ConversionPatternRewriter &rewriter) const; std::pair - getStackedSubviews(memref::ReinterpretCastOp chunk1, - memref::ReinterpretCastOp chunk2, const Location loc, + getStackedSubviews(Value block1, Value block2, const Location loc, ConversionPatternRewriter &rewriter) const; private: diff --git a/include/triton-shared/Analysis/PtrAnalysis.h b/include/triton-shared/Analysis/PtrAnalysis.h index a96ee8cb..5a95ebda 100644 --- a/include/triton-shared/Analysis/PtrAnalysis.h +++ b/include/triton-shared/Analysis/PtrAnalysis.h @@ -24,9 +24,14 @@ namespace triton { struct ModuloState { Value size; - OpFoldResult offset; - ModuloState() {} - ModuloState(Value size, OpFoldResult offset) : size{size}, offset{offset} {} + + // offset is used to determine the wraparound point for patterns like: + // offset + (tl.arange(0, 256) % 12) + // The current code assumes that the modulo operator always runs last, e.g: + // (offset + tl.arange(0, 256)) % 12 + // This is not used at the moment as there haven't been enough use cases and + // the implementation is quite complex. + // OpFoldResult offset; static constexpr char const *WraparoundAttr = "ptr.wraparound_type"; static constexpr char const *WraparoundStacked = "stacked"; @@ -61,7 +66,8 @@ class PtrState { bool hasModulo() const; MemRefType getResultMemrefType(MLIRContext *context, int64_t offset, - ArrayRef resultShape) const; + ArrayRef resultShape, + bool useDynamicStrides = false) const; // Process addition of two PtrStates. void addState(const PtrState &lhsState, const PtrState &rhsState, diff --git a/lib/Analysis/MaskAnalysis.cpp b/lib/Analysis/MaskAnalysis.cpp index 2e54a151..18af487a 100644 --- a/lib/Analysis/MaskAnalysis.cpp +++ b/lib/Analysis/MaskAnalysis.cpp @@ -135,49 +135,42 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, // + colView2 = colFull - colView1 // + rowView1 = rowView2 = row = rowFull std::pair -MaskState::getSideBySideSubviews(memref::ReinterpretCastOp block1, - memref::ReinterpretCastOp block2, - const Location loc, +MaskState::getSideBySideSubviews(Value block1, Value block2, const Location loc, ConversionPatternRewriter &rewriter) const { - - assert(block1.getResultRank() == 2 && block2.getResultRank() == 2 && - getRank() == 2); - OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; - OpFoldResult col1 = block1.getMixedSizes()[1]; + OpFoldResult col1 = + rewriter.create(loc, block1, 1).getResult(); OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, rewriter); OpFoldResult subviewCol2 = subOFRs(subviewColFull, subviewCol1, loc, rewriter); SmallVector offsets(getRank(), rewriter.getIndexAttr(0)); - SmallVector strides = block1.getMixedStrides(); - auto sv1 = createSubview(block1.getResult(), loc, rewriter, offsets, + SmallVector strides(getRank(), rewriter.getIndexAttr(1)); + auto sv1 = createSubview(block1, loc, rewriter, offsets, {subviewRowFull, subviewCol1}, strides); - auto sv2 = createSubview(block2.getResult(), loc, rewriter, offsets, + auto sv2 = createSubview(block2, loc, rewriter, offsets, {subviewRowFull, subviewCol2}, strides); return {sv1, sv2}; } -std::pair MaskState::getStackedSubviews( - memref::ReinterpretCastOp block1, memref::ReinterpretCastOp block2, - const Location loc, ConversionPatternRewriter &rewriter) const { - assert(block1.getResultRank() == 2 && block2.getResultRank() == 2 && - getRank() == 2); - +std::pair +MaskState::getStackedSubviews(Value block1, Value block2, const Location loc, + ConversionPatternRewriter &rewriter) const { OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; - OpFoldResult row1 = block1.getMixedSizes()[0]; + OpFoldResult row1 = + rewriter.create(loc, block1, 0).getResult(); OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, rewriter); OpFoldResult subviewRow2 = subOFRs(subviewRowFull, subviewRow1, loc, rewriter); SmallVector offsets(getRank(), rewriter.getIndexAttr(0)); - SmallVector strides = block1.getMixedStrides(); - auto sv1 = createSubview(block1.getResult(), loc, rewriter, offsets, + SmallVector strides(getRank(), rewriter.getIndexAttr(1)); + auto sv1 = createSubview(block1, loc, rewriter, offsets, {subviewRow1, subviewColFull}, strides); - auto sv2 = createSubview(block2.getResult(), loc, rewriter, offsets, + auto sv2 = createSubview(block2, loc, rewriter, offsets, {subviewRow2, subviewColFull}, strides); return {sv1, sv2}; } diff --git a/lib/Analysis/PtrAnalysis.cpp b/lib/Analysis/PtrAnalysis.cpp index 894f5300..e92126e3 100644 --- a/lib/Analysis/PtrAnalysis.cpp +++ b/lib/Analysis/PtrAnalysis.cpp @@ -29,11 +29,16 @@ static void assertValidUnrealizedCast(UnrealizedConversionCastOp op) { } MemRefType PtrState::getResultMemrefType(MLIRContext *context, int64_t offset, - ArrayRef resultShape) const { + ArrayRef resultShape, + bool useDynamicStrides) const { SmallVector staticStrides; - SmallVector dynamicStrides; - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + if (useDynamicStrides) { + staticStrides.append(strides.size(), ShapedType::kDynamic); + } else { + SmallVector dynamicStrides; + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + } auto elementType = source.getType().cast().getElementType(); auto layout = @@ -196,7 +201,8 @@ PtrState::createStackedCastOps(ArrayRef resultShape, // the same as the original row. The last chunk // may be smaller due to wrapping around. resultShape[1], // Col stays the same. - }); + }, + true /*useDynamicStrides*/); Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter); Value colSize = ofrToIndexValue(sizes[1], loc, rewriter); @@ -287,7 +293,8 @@ PtrState::createSideBySideCastOps(ArrayRef resultShape, ShapedType::kDynamic // Column is dynamic, in most cases, this should // be the same as the original column. The last // chunk may be smaller due to wrapping around. - }); + }, + true /*useDynamicStrides*/); Value rowSize = ofrToIndexValue(sizes[0], loc, rewriter); Value colSize = ofrToIndexValue(sizes[1], loc, rewriter); @@ -295,22 +302,24 @@ PtrState::createSideBySideCastOps(ArrayRef resultShape, Value modN = rewriter.create(loc, rewriter.getIndexType(), modulos[1]->size); - SmallVector strideVals = ofrsToIndexValues(strides, loc, rewriter); - Value x = rewriter.create(loc, targetOffset, modN); Value y = rewriter.create(loc, targetOffset, x); + SmallVector strideVals = ofrsToIndexValues(strides, loc, rewriter); + // First chunk Value nextOffset = rewriter.create(loc, x, colSize); Value clampedOffset = rewriter.create(loc, nextOffset, modN); Value d1 = rewriter.create(loc, clampedOffset, x); SmallVector sizes1{rowSize, d1}; + auto cast1 = rewriter.create( loc, resultType, source, targetOffset, sizes1, strideVals); // Second chunk Value d2 = rewriter.create(loc, colSize, d1); SmallVector sizes2{rowSize, d2}; + auto cast2 = rewriter.create( loc, resultType, source, y, sizes2, strideVals); @@ -372,15 +381,46 @@ void PtrAnalysis::visitOperandRem( ConversionPatternRewriter &rewriter, const llvm::SmallDenseMap &knownPtrs) { assert(state.isEmpty()); - visitOperand(remOp.getLhs(), state, loc, rewriter, knownPtrs); - assert(state.getRank() == 1 && !state.modulos.back().has_value() && - "No support for multiple modulos within an expression"); PtrState rhsState; visitOperand(remOp.getRhs(), rhsState, loc, rewriter, knownPtrs); assert(rhsState.scalar); - state.modulos.back() = ModuloState(rhsState.scalar, rewriter.getIndexAttr(0)); + visitOperand(remOp.getLhs(), state, loc, rewriter, knownPtrs); + + // If there are multiple modulo ops on an expression (e.g.: (a % b) % c), we + // would have already populated the modulo states after visiting the lhs. + // Assert that all the modulo states are empty. + assert(llvm::all_of(state.modulos, + [](auto modState) { return !modState.has_value(); }) && + "No support for multiple modulo within an expression"); + + if (state.getRank() == 1) { + // Apply the modulo before expanding shape, the common pattern is + // offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + // a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * + // stride_ak) + state.modulos.back() = ModuloState{rhsState.scalar}; + } else if (state.getRank() == 2) { + // torch inductor expands the tensor shape before applying the modulo. + // + // We only support either: + // - (tl.arange(0, end)[:, None] % mod), or + // - (tl.arange(0, end)[None, :] % mod) + // + // In both cases, we apply the modulo to the non-singleton dimension. + auto shape = cast(remOp.getResult().getType()).getShape(); + if (shape[0] == 1) { + state.modulos[1] = ModuloState{rhsState.scalar}; + } else if (shape[1] == 1) { + state.modulos[0] = ModuloState{rhsState.scalar}; + } else { + assert(false && "Taking modulo on a 2D tensor with no singleton " + "dimension not supported"); + } + } else { + assert(false && "Unsupported modulo pattern"); + } } void PtrAnalysis::visitOperandMakeRange( @@ -1223,16 +1263,7 @@ void PtrAnalysis::rewriteForOp( for (auto &[unrealizedCastOp, chunkData, state] : moduloStates) { SmallVector newReinterpretCasts; for (auto &chunk : chunkData) { - auto initReintCast = - chunk.reinterpretCast - .getDefiningOp(); - - auto newReintCast = b.create( - loc, initReintCast.getResult().getType(), - args[chunk.initArgIndex], zero, initReintCast.getSizes(), - initReintCast.getStrides()); - - newReinterpretCasts.push_back(newReintCast); + newReinterpretCasts.push_back(args[chunk.initArgIndex]); } auto combinedCast = b.create( diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp index e50907d9..66acd25d 100644 --- a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp +++ b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp @@ -489,62 +489,76 @@ struct LoadConverter : public OpConversionPattern { private: using OpConversionPattern::OpConversionPattern; - template - void createSideBySideCopies(SourceOpTy block1, SourceOpTy block2, Value dst, + void createSideBySideCopies(Value block1, Value block2, Value dst, Location loc, ConversionPatternRewriter &rewriter) const { - static_assert((std::is_same() || - std::is_same()) && - "Expect source of split pointers to come from either " - "reinterpret_cast or subview ops"); auto zero = rewriter.create(loc, rewriter.getIndexAttr(0)); - auto block1Dst = rewriter.create( - loc, dst, /* offsets */ ValueRange{zero, zero}, - ofrsToIndexValues(block1.getMixedSizes(), loc, rewriter), - ofrsToIndexValues(block1.getMixedStrides(), loc, rewriter)); - - auto block2Dst = rewriter.create( - loc, dst, - /* offsets */ - ValueRange{zero, - ofrToIndexValue(block1.getMixedSizes()[1], loc, rewriter)}, - ofrsToIndexValues(block2.getMixedSizes(), loc, rewriter), - ofrsToIndexValues(block2.getMixedStrides(), loc, rewriter)); - - rewriter.create(loc, block1.getResult(), block1Dst); - rewriter.create(loc, block2.getResult(), block2Dst); + auto one = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + Value block1Row = rewriter.create(loc, block1, 0); + Value block1Col = rewriter.create(loc, block1, 1); + + Value block2Row = rewriter.create(loc, block2, 0); + Value block2Col = rewriter.create(loc, block2, 1); + + auto block1Dst = + rewriter.create(loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = + rewriter.create(loc, dst, + /* offsets */ + ValueRange{zero, block1Col}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + rewriter.create(loc, block1, block1Dst); + rewriter.create(loc, block2, block2Dst); } - template - void createStackedCopies(SourceOpTy block1, SourceOpTy block2, Value dst, - Location loc, + void createStackedCopies(Value block1, Value block2, Value dst, Location loc, ConversionPatternRewriter &rewriter) const { - static_assert((std::is_same() || - std::is_same()) && - "Expect source of split pointers to come from either " - "reinterpret_cast or subview ops"); auto zero = rewriter.create(loc, rewriter.getIndexAttr(0)); - - auto block1Dst = rewriter.create( - loc, dst, /* offsets */ ValueRange{zero, zero}, - ofrsToIndexValues(block1.getMixedSizes(), loc, rewriter), - ofrsToIndexValues(block1.getMixedStrides(), loc, rewriter)); - - auto block2Dst = rewriter.create( - loc, dst, - /* offsets */ - ValueRange{ofrToIndexValue(block1.getMixedSizes()[0], loc, rewriter), - zero}, - ofrsToIndexValues(block2.getMixedSizes(), loc, rewriter), - ofrsToIndexValues(block2.getMixedStrides(), loc, rewriter)); - - rewriter.create(loc, block1.getResult(), block1Dst); - rewriter.create(loc, block2.getResult(), block2Dst); + auto one = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + Value block1Row = rewriter.create(loc, block1, 0); + Value block1Col = rewriter.create(loc, block1, 1); + + Value block2Row = rewriter.create(loc, block2, 0); + Value block2Col = rewriter.create(loc, block2, 1); + + auto block1Dst = + rewriter.create(loc, dst, /* offsets */ + ValueRange{zero, zero}, + /* sizes */ + ValueRange{block1Row, block1Col}, + /* strides */ + ValueRange{one, one}); + + auto block2Dst = + rewriter.create(loc, dst, + /* offsets */ + ValueRange{block1Row, zero}, + /* sizes */ + ValueRange{block2Row, block2Col}, + /* strides */ + ValueRange{one, one}); + + rewriter.create(loc, block1, block1Dst); + rewriter.create(loc, block2, block2Dst); } public: @@ -591,8 +605,8 @@ struct LoadConverter : public OpConversionPattern { ModuloState::WraparoundAttr)) { auto memrefs = unrealizedCast.getOperands(); - auto block1 = memrefs[0].getDefiningOp(); - auto block2 = memrefs[1].getDefiningOp(); + auto block1 = memrefs[0]; + auto block2 = memrefs[1]; if (wrapType.getValue().equals(ModuloState::WraparoundSideBySide)) { createSideBySideCopies(block1, block2, alloc, loc, rewriter); @@ -671,8 +685,8 @@ struct LoadConverter : public OpConversionPattern { ModuloState::WraparoundAttr)) { auto memrefs = unrealizedCast.getOperands(); - auto block1 = memrefs[0].getDefiningOp(); - auto block2 = memrefs[1].getDefiningOp(); + auto block1 = memrefs[0]; + auto block2 = memrefs[1]; if (wrapType.getValue().equals(ModuloState::WraparoundSideBySide)) { auto [subview1, subview2] = diff --git a/python/__init__.py b/python/__init__.py index 35269099..23c20df7 100644 --- a/python/__init__.py +++ b/python/__init__.py @@ -67,6 +67,12 @@ def _ttsharedir_to_llir(ttsharedir: str): "--expand-strided-metadata", "--finalize-memref-to-llvm", "--convert-func-to-llvm", + # Lowering memrefs creates more affine.apply ops. + # Lowering these affine ops again creates further arith ops, + # so we have to run these two passes again here. + "--lower-affine", + "--convert-arith-to-llvm", + # Remove all unrealized casts created "--reconcile-unrealized-casts", "-o", llmlir_path]) diff --git a/python/examples/test_modulo.py b/python/examples/test_modulo.py new file mode 100644 index 00000000..da3c02cc --- /dev/null +++ b/python/examples/test_modulo.py @@ -0,0 +1,272 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def wrap_stacked_masked_loop( + a_ptr, c_ptr, M, N, stride_am, stride_an, stride_cm, stride_cn, BLOCK_SIZE_K: tl.constexpr +): + offs_am = (2 + tl.arange(0, BLOCK_SIZE_K)) % M + offs_an = 3 + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) + + offs_cm = tl.arange(0, BLOCK_SIZE_K) + offs_cn = tl.arange(0, BLOCK_SIZE_K) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + + for k in range(0, 2): + a = tl.load(a_ptrs, mask=offs_k[None, :] < 3, other=-99) + tl.store(c_ptrs, a) + a_ptrs += BLOCK_SIZE_K * stride_an + c_ptrs += BLOCK_SIZE_K * stride_an + + +@triton.jit +def wrap_side_by_side_masked_loop( + a_ptr, c_ptr, M, N, stride_am, stride_an, stride_cm, stride_cn, BLOCK_SIZE_K: tl.constexpr +): + offs_am = 2 + tl.arange(0, BLOCK_SIZE_K) + offs_an = (6 + tl.arange(0, BLOCK_SIZE_K)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) + + offs_k = tl.arange(0, BLOCK_SIZE_K) + + offs_cm = tl.arange(0, BLOCK_SIZE_K) + offs_cn = tl.arange(0, BLOCK_SIZE_K) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + + for k in range(0, 2): + a = tl.load(a_ptrs, mask=offs_k[:, None] < 2, other=-99) + tl.store(c_ptrs, a) + a_ptrs += BLOCK_SIZE_K * stride_am + c_ptrs += BLOCK_SIZE_K * stride_an + + +@triton.jit +def wrap_side_by_side_loop( + a_ptr, c_ptr, M, N, stride_am, stride_an, stride_cm, stride_cn, BLOCK_SIZE_K: tl.constexpr +): + offs_am = tl.arange(0, 4) + offs_an = (6 + tl.arange(0, 4)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) + + offs_cm = tl.arange(0, 4) + offs_cn = tl.arange(0, 4) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + + for k in range(0, 3): + a = tl.load(a_ptrs) + tl.store(c_ptrs, a) + a_ptrs += BLOCK_SIZE_K * stride_am + c_ptrs += BLOCK_SIZE_K * stride_am + + + +@triton.jit +def wrap_side_by_side_loop_unroll( + a_ptr, c_ptr, M, N, stride_am, stride_an, stride_cm, stride_cn, BLOCK_SIZE_K: tl.constexpr +): + offs_am = tl.arange(0, 4) + offs_an = (6 + tl.arange(0, 4)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) + + offs_cm = tl.arange(0, 4) + offs_cn = tl.arange(0, 4) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + + a = tl.load(a_ptrs) + tl.store(c_ptrs, a) + a_ptrs += BLOCK_SIZE_K * stride_am + c_ptrs += BLOCK_SIZE_K * stride_an + + a = tl.load(a_ptrs) + tl.store(c_ptrs, a) + a_ptrs += BLOCK_SIZE_K * stride_am + c_ptrs += BLOCK_SIZE_K * stride_an + + a = tl.load(a_ptrs) + tl.store(c_ptrs, a) + a_ptrs += BLOCK_SIZE_K * stride_am + c_ptrs += BLOCK_SIZE_K * stride_an + + +@triton.jit +def mod_1d( + a_ptr, c_ptr, M, N, stride_am, stride_an, stride_cm, stride_cn, BLOCK_SIZE_K: tl.constexpr +): + row = 7 + offs_an = (6 + tl.arange(0, 4)) % N + a_ptrs = a_ptr + (row * stride_am) + offs_an[None, :] * stride_an + + offs_cn = tl.arange(0, 4) + c_ptrs = c_ptr + stride_cn * offs_cn[None, :] + + a = tl.load(a_ptrs) + tl.store(c_ptrs, a) + + +@triton.jit +def mod_2d( + a_ptr, c_ptr, M, N, stride_am, stride_an, stride_cm, stride_cn, BLOCK_SIZE_K: tl.constexpr +): + offs_am = 2 + tl.arange(0, 4) + offs_an = (6 + tl.arange(0, 4)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) + + offs_cm = tl.arange(0, 4) + offs_cn = tl.arange(0, 4) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + + a = tl.load(a_ptrs) + tl.store(c_ptrs, a) + + +def test_side_by_side(): + M = 12 + N = 8 + A = torch.arange(0, M * N, device="cpu", dtype=torch.float32).reshape((M, N)) + out = torch.full((M, N), 88888, device="cpu", dtype=torch.float32) + print(out) + grid = lambda meta: (1,) + + wrap_side_by_side_masked_loop[grid]( + A, + out, + M, + N, + A.stride(0), + A.stride(1), + out.stride(0), + out.stride(1), + BLOCK_SIZE_K=4 + ) + + # Expected output copied from running triton on NVDIA gpu + expected_out = torch.tensor([[ 22, 23, 16, 17, 54, 55, 48, 49], + [ 30, 31, 24, 25, 62, 63, 56, 57], + [ -99, -99, -99, -99, -99, -99, -99, -99], + [ -99, -99, -99, -99, -99, -99, -99, -99], + [88888, 88888, 88888, 88888, 88888, 88888, 88888, 88888], + [88888, 88888, 88888, 88888, 88888, 88888, 88888, 88888], + [88888, 88888, 88888, 88888, 88888, 88888, 88888, 88888], + [88888, 88888, 88888, 88888, 88888, 88888, 88888, 88888], + [88888, 88888, 88888, 88888, 88888, 88888, 88888, 88888], + [88888, 88888, 88888, 88888, 88888, 88888, 88888, 88888], + [88888, 88888, 88888, 88888, 88888, 88888, 88888, 88888], + [88888, 88888, 88888, 88888, 88888, 88888, 88888, 88888]], dtype=torch.int32) + + + print(A) + print(out.int()) + assert torch.equal(expected_out.int(), out.int()) + print('Hooooray') + + +def test_stacked(): + M = 4 + N = 12 + BLOCK_SIZE_M = 4 + BLOCK_SIZE_N = 4 + A = torch.arange(0, M * N, device="cpu", dtype=torch.float32).reshape((M, N)) + out = torch.full((BLOCK_SIZE_M, N), 88888, device="cpu", dtype=torch.float32) + print(out) + grid = lambda meta: (1,) + + wrap_stacked_masked_loop[grid]( + A, + out, + M, + N, + A.stride(0), + A.stride(1), + out.stride(0), + out.stride(1), + BLOCK_SIZE_K=4 + ) + + # Expected output copied from running triton on NVDIA gpu + expected_out = torch.tensor( + [ + [27.0, 28.0, 29.0, -99.0, 31.0, 32.0, 33.0, -99.0, 88888, 88888, 88888, 88888,], + [39.0, 40.0, 41.0, -99.0, 43.0, 44.0, 45.0, -99.0, 88888, 88888, 88888, 88888,], + [3.0, 4.0, 5.0, -99.0, 7.0, 8.0, 9.0, -99.0, 88888, 88888, 88888, 88888,], + [15.0, 16.0, 17.0, -99.0, 19.0, 20.0, 21.0, -99.0, 88888, 88888, 88888, 88888,], + ], + ) + + print(A) + print(out.int()) + assert torch.equal(expected_out.int(), out.int()) + print('Passed') + + +def test_torch_inductor_pattern(): + def compile(): + ret = triton.compile(triton_, signature="*i32,*i32,i32,i32,i32", constants={"XBLOCK": 64, "RBLOCK": 64}) + print(ret.asm["ttir"]) + + + @triton.jit + def triton_(in_ptr2, out_ptr2, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): + xnumel = 128 + rnumel = 32 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + rbase = tl.arange(0, RBLOCK)[None, :] + x0 = xindex % 7 + x0 = xindex + roffset = 0 + rindex = roffset + rbase + rmask = rindex < rnumel + r2 = rindex + tmp3 = tl.load(in_ptr2 + (r2 + (xnumel*x0)), rmask, other=77) + tl.store(out_ptr2 + (XBLOCK * tl.arange(0, RBLOCK)[None, :] + tl.arange(0, XBLOCK)[:, None]), tmp3) + + device = "cpu" + xnumel = 128 + rnumel = 32 + + XBLOCK = 4 + RBLOCK = 64 + A = torch.arange(0, xnumel * rnumel, device=device, dtype=torch.int32).reshape((xnumel, rnumel)) + out = torch.full((XBLOCK, RBLOCK), 88888, device=device, dtype=torch.int32) + grid = lambda meta: (1,) + + triton_[grid]( + A, + out, + rnumel, + XBLOCK=XBLOCK, + RBLOCK=RBLOCK + ) + + # Expected output copied from running triton on NVDIA gpu + expected_out = torch.tensor([[ 0, 128, 256, 384, 1, 129, 257, 385, 2, 130, 258, 386, 3, 131, + 259, 387, 4, 132, 260, 388, 5, 133, 261, 389, 6, 134, 262, 390, + 7, 135, 263, 391, 8, 136, 264, 392, 9, 137, 265, 393, 10, 138, + 266, 394, 11, 139, 267, 395, 12, 140, 268, 396, 13, 141, 269, 397, + 14, 142, 270, 398, 15, 143, 271, 399], + [ 16, 144, 272, 400, 17, 145, 273, 401, 18, 146, 274, 402, 19, 147, + 275, 403, 20, 148, 276, 404, 21, 149, 277, 405, 22, 150, 278, 406, + 23, 151, 279, 407, 24, 152, 280, 408, 25, 153, 281, 409, 26, 154, + 282, 410, 27, 155, 283, 411, 28, 156, 284, 412, 29, 157, 285, 413, + 30, 158, 286, 414, 31, 159, 287, 415], + [ 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, + 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, + 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, + 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, + 77, 77, 77, 77, 77, 77, 77, 77], + [ 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, + 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, + 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, + 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, + 77, 77, 77, 77, 77, 77, 77, 77]], device=device, + dtype=torch.int32) + + print(out) + assert torch.equal(expected_out.int(), out.int()) + print('Passed') diff --git a/python/examples/test_reduce.py b/python/examples/test_reduce.py index 28a40a2c..6524097a 100644 --- a/python/examples/test_reduce.py +++ b/python/examples/test_reduce.py @@ -51,3 +51,5 @@ def test(): print(ret.asm["ttsharedir"]) print(ret.asm["llir"]) print(ret.asm["cpuasm"]) + +test() diff --git a/python/examples/test_softmax.py b/python/examples/test_softmax.py new file mode 100644 index 00000000..e07055fa --- /dev/null +++ b/python/examples/test_softmax.py @@ -0,0 +1,65 @@ +import torch + +import triton +import triton.language as tl + +import torch + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr): + # The rows of the softmax are independent, so we parallelize across those + row_idx = tl.program_id(0) + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) + + +def softmax(x): + n_rows, n_cols = x.shape + # The block size is the smallest power of two greater than the number of columns in `x` + BLOCK_SIZE = triton.next_power_of_2(n_cols) + # Another trick we can use is to ask the compiler to use more threads per row by + # increasing the number of warps (`num_warps`) over which each row is distributed. + # You will see in the next tutorial how to auto-tune this value in a more natural + # way so you don't have to come up with manual heuristics yourself. + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + # Allocate output + y = torch.empty_like(x) + # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o + # f the input matrix + softmax_kernel[(n_rows, )]( + y, + x, + x.stride(0), + y.stride(0), + n_cols, + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return y + +def test_softmax(): + torch.manual_seed(0) + x = torch.randn(1823, 781, device='cpu') + y_triton = softmax(x) + y_torch = torch.softmax(x, axis=1) + assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) diff --git a/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir b/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir index ffe59823..a182a934 100644 --- a/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir +++ b/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir @@ -56,11 +56,12 @@ module { // CHECK-LABEL: func.func @wrap_side_by_side_masked_loop_01234567 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index // CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : i32 // CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 // CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 @@ -94,20 +95,19 @@ module { // CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[VAR_19_]], [[CST_6_]] : index // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_21_:%.+]]:6 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg16_:%.+]] = [[VAR_reinterpret_cast_]]_1, [[VAR_arg17_:%.+]] = [[VAR_17_]], [[VAR_arg18_:%.+]] = [[CST_0_]], [[VAR_arg19_:%.+]] = [[CST_0_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]]_0) -> (memref<4x?xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<4x?xf32, strided<[?, ?], offset: ?>>) : i32 { -// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[VAR_arg15_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_10_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_reinterpret_cast_3_:%.+]] = memref.reinterpret_cast [[VAR_arg20_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_11_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_21_:%.+]]:6 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg16_:%.+]] = [[VAR_reinterpret_cast_]]_1, [[VAR_arg17_:%.+]] = [[VAR_17_]], [[VAR_arg18_:%.+]] = [[CST_0_]], [[VAR_arg19_:%.+]] = [[CST_0_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]]_0) -> (memref<4x?xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<4x?xf32, strided<[?, ?], offset: ?>>) : i32 { // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> // CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK: [[VAR_22_:%.+]] = arith.minsi [[VAR_10_]], [[CST_4_]] : index +// CHECK: [[VAR_dim_:%.+]] = memref.dim [[VAR_arg15_]], [[CST_1_]] : memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK: [[VAR_22_:%.+]] = arith.minsi [[VAR_dim_]], [[CST_4_]] : index // CHECK-DAG: [[VAR_23_:%.+]] = arith.subi [[CST_4_]], [[VAR_22_]] : index -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_2_]][0, 0] [2, [[VAR_22_]]{{.}} {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_arg15_]][0, 0] [2, [[VAR_22_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[VAR_reinterpret_cast_3_]][0, 0] [2, [[VAR_23_]]{{.}} {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_subview_5_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_22_]]{{.}} {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x4xf32> to memref<2x?xf32, strided<[?, ?]>> -// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0, [[VAR_22_]]{{.}} [2, [[VAR_23_]]{{.}} {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<4x4xf32> to memref<2x?xf32, strided<[?, ?], offset: ?>> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_5 : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?]>> -// CHECK: memref.copy [[VAR_subview_4_]], [[VAR_subview_6_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_arg20_]][0, 0] [2, [[VAR_23_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_22_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>> +// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_22_]]{{.}} [2, [[VAR_23_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>> +// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_3 : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>> +// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1], offset: ?>> // CHECK: [[VAR_24_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> // CHECK: memref.tensor_store [[VAR_24_]], [[VAR_arg16_]] : memref<4x4xf32, strided<[?, ?], offset: ?>> // CHECK: [[VAR_25_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index @@ -118,15 +118,15 @@ module { // CHECK-DAG: [[VAR_30_:%.+]] = arith.addi [[VAR_28_]], [[CST_4_]] : index // CHECK: [[VAR_31_:%.+]] = arith.minsi [[VAR_30_]], [[VAR_18_]] : index // CHECK: [[VAR_32_:%.+]] = arith.subi [[VAR_31_]], [[VAR_28_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_32_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_19_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_32_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_19_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> // CHECK-DAG: [[VAR_33_:%.+]] = arith.subi [[CST_4_]], [[VAR_32_]] : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_8_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_29_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_33_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_19_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_reinterpret_cast_6_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_29_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_33_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_19_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> // CHECK-DAG: [[VAR_34_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index // CHECK: [[VAR_35_:%.+]] = arith.addi [[VAR_arg18_]], [[VAR_34_]] : index // CHECK: [[VAR_36_:%.+]] = arith.addi [[VAR_35_]], [[VAR_arg19_]] : index -// CHECK: [[VAR_reinterpret_cast_9_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_36_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK: scf.yield [[VAR_reinterpret_cast_7_]], [[VAR_reinterpret_cast_9_]], [[VAR_26_]], [[VAR_36_]], [[CST_0_]], [[VAR_reinterpret_cast_8_]] : memref<4x?xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<4x?xf32, strided<[?, ?], offset: ?>> +// CHECK: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_36_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> +// CHECK: scf.yield [[VAR_reinterpret_cast_5_]], [[VAR_reinterpret_cast_7_]], [[VAR_26_]], [[VAR_36_]], [[CST_0_]], [[VAR_reinterpret_cast_6_]] : memref<4x?xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<4x?xf32, strided<[?, ?], offset: ?>> // CHECK: } // CHECK: return // CHECK: } diff --git a/test/Conversion/TritonToLinalg/wraparound_stacked.mlir b/test/Conversion/TritonToLinalg/wraparound_stacked.mlir index 37676ab6..40db0bba 100644 --- a/test/Conversion/TritonToLinalg/wraparound_stacked.mlir +++ b/test/Conversion/TritonToLinalg/wraparound_stacked.mlir @@ -58,15 +58,15 @@ module { // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 // CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : i32 // CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 -// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[VAR_1_]], [[CST_2_1_]] : index +// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[VAR_1_]], [[CST_2_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.muli [[VAR_3_]], [[CST_3_]] : index // CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_2_]], [[VAR_4_]] : index @@ -85,26 +85,25 @@ module { // CHECK-DAG: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_:%.+]] = arith.muli [[VAR_16_]], [[CST_2_1_]] : index +// CHECK-DAG: [[VAR_17_:%.+]] = arith.muli [[VAR_16_]], [[CST_2_]] : index // CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_19_:%.+]] = arith.muli [[VAR_18_]], [[CST_3_]] : index // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]]:6 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg16_:%.+]] = [[VAR_reinterpret_cast_]]_1, [[VAR_arg17_:%.+]] = [[VAR_17_]], [[VAR_arg18_:%.+]] = [[CST_0_]], [[VAR_arg19_:%.+]] = [[CST_0_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]]_0) -> (memref>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref>) : i32 { -// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[VAR_arg15_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: {{.}}[[VAR_10_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref> to memref> -// CHECK-DAG: [[VAR_reinterpret_cast_3_:%.+]] = memref.reinterpret_cast [[VAR_arg20_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: {{.}}[[VAR_11_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref> to memref> +// CHECK-DAG: [[VAR_20_:%.+]]:6 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_1_]] to [[CST_2_1_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg16_:%.+]] = [[VAR_reinterpret_cast_]]_1, [[VAR_arg17_:%.+]] = [[VAR_17_]], [[VAR_arg18_:%.+]] = [[CST_0_]], [[VAR_arg19_:%.+]] = [[CST_0_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]]_0) -> (memref>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref>) : i32 { // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> // CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK: [[VAR_21_:%.+]] = arith.minsi [[VAR_10_]], [[CST_4_]] : index +// CHECK: [[VAR_dim_:%.+]] = memref.dim [[VAR_arg15_]], [[CST_0_]] : memref> +// CHECK: [[VAR_21_:%.+]] = arith.minsi [[VAR_dim_]], [[CST_4_]] : index // CHECK-DAG: [[VAR_22_:%.+]] = arith.subi [[CST_4_]], [[VAR_21_]] : index -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_2_]][0, 0] {{.}}[[VAR_21_]], 3] {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_arg15_]][0, 0] {{.}}[[VAR_21_]], 3] [1, 1] : memref> to memref> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[VAR_reinterpret_cast_3_]][0, 0] {{.}}[[VAR_22_]], 3] {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref> to memref> -// CHECK-DAG: [[VAR_subview_5_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_21_]], 3] {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<4x4xf32> to memref> -// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_21_]], 0] {{.}}[[VAR_22_]], 3] {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<4x4xf32> to memref> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_5 : memref> to memref> -// CHECK: memref.copy [[VAR_subview_4_]], [[VAR_subview_6_]] : memref> to memref> +// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_arg20_]][0, 0] {{.}}[[VAR_22_]], 3] [1, 1] : memref> to memref> +// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_21_]], 3] [1, 1] : memref<4x4xf32> to memref> +// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_21_]], 0] {{.}}[[VAR_22_]], 3] [1, 1] : memref<4x4xf32> to memref> +// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_3 : memref> to memref> +// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref> to memref> // CHECK: [[VAR_23_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> // CHECK: memref.tensor_store [[VAR_23_]], [[VAR_arg16_]] : memref<4x4xf32, strided<[?, ?], offset: ?>> // CHECK: [[VAR_24_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index @@ -115,15 +114,15 @@ module { // CHECK: [[VAR_29_:%.+]] = arith.addi [[VAR_28_]], [[VAR_27_]] : index // CHECK: [[VAR_30_:%.+]] = arith.subi [[VAR_29_]], [[VAR_26_]] : index // CHECK: [[VAR_31_:%.+]] = arith.divsi [[VAR_30_]], [[VAR_16_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_26_]]{{.}}, sizes: {{.}}[[VAR_31_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_18_]]{{.}} : memref<*xf32> to memref> +// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_26_]]{{.}}, sizes: {{.}}[[VAR_31_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_18_]]{{.}} : memref<*xf32> to memref> // CHECK-DAG: [[VAR_32_:%.+]] = arith.subi [[CST_4_]], [[VAR_31_]] : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_8_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: {{.}}[[VAR_32_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_18_]]{{.}} : memref<*xf32> to memref> +// CHECK-DAG: [[VAR_reinterpret_cast_6_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: {{.}}[[VAR_32_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_18_]]{{.}} : memref<*xf32> to memref> // CHECK-DAG: [[VAR_33_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index // CHECK: [[VAR_34_:%.+]] = arith.addi [[VAR_arg18_]], [[VAR_33_]] : index // CHECK: [[VAR_35_:%.+]] = arith.addi [[VAR_34_]], [[VAR_arg19_]] : index -// CHECK: [[VAR_reinterpret_cast_9_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_35_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK: scf.yield [[VAR_reinterpret_cast_7_]], [[VAR_reinterpret_cast_9_]], [[VAR_25_]], [[VAR_35_]], [[CST_0_]], [[VAR_reinterpret_cast_8_]] : memref>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref> +// CHECK: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_35_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> +// CHECK: scf.yield [[VAR_reinterpret_cast_5_]], [[VAR_reinterpret_cast_7_]], [[VAR_25_]], [[VAR_35_]], [[CST_0_]], [[VAR_reinterpret_cast_6_]] : memref>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref> // CHECK: } // CHECK: return // CHECK: }