[LAYOUTS] Use LLs for Hopper whenever we wouldn't use ldmatrix #5235

Open · wants to merge 1 commit into main

3 changes: 3 additions & 0 deletions .gitignore
@@ -26,6 +26,9 @@ python/triton/language/extra
# Proton
python/triton/profiler

+ # Pytest
+ pytest.ini
+
# Instrumentation
python/triton/instrumentation

4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -360,7 +360,7 @@ compared to 1*64 when the hasLeadingOffset is false.
int k = (needTrans) ? matShape[0] : matShape[2];
int vec = (order[0] == rank-1) ? k : m;
int mmaStride = (order[0] == rank-1) ? m : k;
- int maxPhase = mmaStride / perPhase;
+ int maxPhase = std::max(mmaStride / perPhase, 1);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

@@ -373,7 +373,7 @@ compared to 1*64 when the hasLeadingOffset is false.
int k = needTrans ? matShape[1] : matShape[2];
int vec = (order[0] == rank-1) ? n : k;
int mmaStride = (order[0] == rank-1) ? k : n;
- int maxPhase = mmaStride / perPhase;
+ int maxPhase = std::max(mmaStride / perPhase, 1);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

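Why the clamp in these two hunks matters: mmaStride / perPhase is integer division, so whenever perPhase exceeds mmaStride the old code produced maxPhase == 0, and the swizzle phase is computed modulo maxPhase. A minimal standalone sketch with hypothetical tile numbers (the (row / perPhase) % maxPhase phase formula is the usual swizzle convention, not code from this patch):

#include <algorithm>
#include <cassert>

// Hypothetical numbers, not taken from the patch: a small MMA tile
// (mmaStride = 8) with a wide swizzle period (perPhase = 16).
int main() {
  int mmaStride = 8, perPhase = 16;
  int unclamped = mmaStride / perPhase;             // integer division -> 0
  int maxPhase = std::max(mmaStride / perPhase, 1); // clamped -> 1
  assert(unclamped == 0 && maxPhase == 1);
  // The conventional swizzle phase is (row / perPhase) % maxPhase; with the
  // unclamped value this is a modulo by zero (UB), with the clamp it is 0.
  int row = 5;
  int phase = (row / perPhase) % maxPhase;
  assert(phase == 0);
  return 0;
}
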
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
@@ -629,7 +629,8 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
dotOperandLayout.getOpIdx() == 0 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
- (elementTypeSize == 16 || elementTypeSize == 8);
+ (elementTypeSize == 16 || elementTypeSize == 8) &&
+ dotOperandLayout.getKWidth() == 32 / elementTypeSize;
return ans;
}

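For context on the new condition: kWidth counts the elements each thread holds contiguously along K, so kWidth == 32 / elementTypeSize singles out the case where one thread's K-pack fills exactly one 32-bit register, which is what the no-op MMAv3-to-dot-operand conversion assumes. A sketch of the arithmetic (my reading of the condition, not code from the patch):

#include <cassert>

// kWidth that makes one thread's contiguous K-elements span 32 bits.
constexpr int canonicalKWidth(int elementTypeSize /* bits */) {
  return 32 / elementTypeSize;
}

int main() {
  assert(canonicalKWidth(16) == 2); // fp16/bf16: two elements per register
  assert(canonicalKWidth(8) == 4);  // fp8/int8: four elements per register
  return 0;
}
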
10 changes: 6 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -384,10 +384,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion) {
return false;
}
- if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-   if (nvidiaMma.isAmpere()) {
-     return true;
-   }
+ if (isa<NvidiaMmaEncodingAttr>(parent)) {
+   return true;
}
if (isa<AMDMfmaEncodingAttr>(parent)) {
return true;
@@ -408,6 +406,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) {
return failure();
}
+ // FIXME [Dot LL] Remove this once we implement this trick in LLs
+ if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
+   return failure();
+ }

// The following check can be removed when generalized warp shuffle
// conversions are ready:
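Taken together, the two hunks above make every NvidiaMmaEncodingAttr parent (Ampere and now Hopper) eligible for the linear-layout conversion, except the MMAv3 dot-operand case that still depends on the register-reuse trick flagged by the FIXME. A boolean sketch of that gating (the names are mine and it is simplified, not the literal MLIR code):

#include <cassert>

// parentIsNvidiaMma: the dot operand's parent is an NVIDIA MMA encoding.
// isMmaV3DotOperandNoop: matchMmaV3AndDotOperandLayout detected the
// register-reuse case that LLs cannot express yet (the FIXME above).
bool takesLinearLayoutPath(bool parentIsNvidiaMma, bool isMmaV3DotOperandNoop) {
  if (!parentIsNvidiaMma)
    return false; // other encodings follow their own checks, unchanged here
  return !isMmaV3DotOperandNoop;
}

int main() {
  assert(takesLinearLayoutPath(true, false)); // Ampere and Hopper: now LLs
  assert(!takesLinearLayoutPath(true, true)); // MMAv3 trick: legacy path
  return 0;
}
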
40 changes: 28 additions & 12 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -138,18 +138,34 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {

// FIXME [Dot LL]
// Do for all DotOperandEncodingAttr once we have LLs for all of them
- static bool isSupportedDotOpLayout(RankedTensorType type) {
-   auto layout = type.getEncoding();
-   auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
-   if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+ static bool isSupportedDotOpLayout(MemDescType srcTy,
+                                    RankedTensorType dstTy) {
+   auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
+   auto dstLayout = dstTy.getEncoding();
+   auto bitwidth = dstTy.getElementTypeBitWidth();
+   auto rank = dstTy.getRank();
+   if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
auto vecWidth = 32 / bitwidth;
auto kWidth = dot.getKWidth();
- // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
- // - kWidth == 8
- // - kWidth == 4, bitwidth = 32
+ auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
+ auto needTrans = kOrder != srcLayout.getOrder()[0];
+ auto canUseLdmatrix =
+     (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
+ if (mma.isHopper()) {
+   // I think we should be able to remove this condition, but it's here
+   // as the legacy ldmatrix path does not support it
+   canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32;
+ }
+ // If we remove this one, ldmatrix will IMA. It can probably be relaxed
+ // though
+ canUseLdmatrix &=
+     srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
// To be removed in https://github.com/triton-lang/triton/pull/5154
bool legacyLoweringIsBuggy =
-     kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
- return legacyLoweringIsBuggy && mma.isAmpere();
+     (kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere();
+ return (mma.isHopper() && !canUseLdmatrix) ||
+        (mma.isAmpere() && legacyLoweringIsBuggy);
}
if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
return true;
@@ -162,12 +178,12 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
ConversionPatternRewriter &rewriter) const override {
MemDescType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();
- Attribute srcLayout = srcTy.getEncoding();
+ auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
Attribute dstLayout = dstTy.getEncoding();
if (isa<SharedEncodingAttr>(srcLayout) &&
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
-   isSupportedDotOpLayout(dstTy))) {
+   isSupportedDotOpLayout(srcTy, dstTy))) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
}
@@ -206,7 +222,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
- assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
+ assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
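To make the PR title concrete, here is the eligibility logic above distilled into a standalone predicate. Plain ints and bools stand in for the MLIR types, src and dst are assumed to share an element bitwidth, the shared-memory shape check is folded into one flag, and non-Hopper is treated as Ampere; it mirrors the diff but is a sketch, not the literal code:

#include <cassert>

struct DotOpCase {
  int bitwidth;   // element bitwidth (assuming src and dst agree)
  int kWidth;     // contiguous elements per thread along K
  bool needTrans; // K is not the fastest-varying dim in shared memory
  bool shapeOk;   // shared-memory shape is at least 8 x (4 * kWidth)
  bool isHopper;  // MMAv3 parent; non-Hopper is treated as Ampere here
};

// True when LocalLoadOp should take the new linear-layout lowering.
bool useLinearLayouts(const DotOpCase &c) {
  int vecWidth = 32 / c.bitwidth;
  bool canUseLdmatrix =
      (c.bitwidth == 16 || !c.needTrans) && c.kWidth == vecWidth;
  if (c.isHopper)
    canUseLdmatrix = canUseLdmatrix && (c.bitwidth * c.kWidth == 32);
  canUseLdmatrix = canUseLdmatrix && c.shapeOk;
  bool legacyLoweringIsBuggy =
      (c.kWidth >= 8 || (c.kWidth == 4 && c.bitwidth == 32)) && !c.isHopper;
  return (c.isHopper && !canUseLdmatrix) ||
         (!c.isHopper && legacyLoweringIsBuggy);
}

int main() {
  // fp16 on Hopper, kWidth == vecWidth == 2: ldmatrix still wins.
  DotOpCase fp16Hopper{16, 2, false, true, true};
  assert(!useLinearLayouts(fp16Hopper));
  // fp8 on Hopper with a transposed K dim: ldmatrix is out, LLs are used.
  DotOpCase fp8Trans{8, 4, true, true, true};
  assert(useLinearLayouts(fp8Trans));
  return 0;
}

On Hopper the rule reduces to "use LLs whenever ldmatrix would not be used", which is what the PR title promises; on Ampere the LL path remains limited to the kWidth combinations where the legacy lowering is known to be buggy.
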
3 changes: 3 additions & 0 deletions python/test/unit/language/test_core.py
@@ -5250,6 +5250,9 @@ def kernel(Out):
# TODO: backend should be tested separately

layouts = [
+ MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]),
+ DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2),
+ DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1),
BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
@@ -459,9 +459,6 @@ void mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
// For now give ConvertLayoutOpConversion higher benefit, I can split before
// merging
- //
- // TODO(jlebar): lowerDistributedToDistributed does not get hit in any
- // testcases. Is this dead code? Does the benefit need to be increased?
patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
// Same default benefit
patterns.add<LocalLoadOpConversion>(typeConverter, benefit);