
[Triton] Add tl.gather with a naive codegen implementation #5262

Merged: 20 commits, merged Nov 28, 2024
Changes from 11 commits
13 changes: 13 additions & 0 deletions include/triton/Analysis/Utility.h
@@ -153,6 +153,19 @@ class ScanLoweringHelper {
SmallVector<Type> srcElementTypes;
};

// Helper class for lowering `tt.gather` operations. This class shares lowering
// logic between shared memory allocation and LLVM codegen.
class GatherLoweringHelper {
public:
GatherLoweringHelper(triton::GatherOp gatherOp);

// Get the shared memory scratch size required by this op.
unsigned getScratchSizeInBytes();

private:
triton::GatherOp gatherOp;
};

// Decomposes a reshape into simpler pieces.
//
// As an example, suppose we have a reshape from [4,4,4] to [2,2,8,2].
@@ -92,6 +92,10 @@ void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);
void populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
5 changes: 5 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1125,6 +1125,11 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,

// Emit indices calculation within each ConversionPattern, and returns a
// [elemsPerThread X rank] index matrix.
//
// For example, if a thread owns `elemsPerThread` elements of a tensor with
// type `type` and layout `layout`, the result will contain `elemsPerThread`
// vectors. Each vector contains the SSA values of the indices required to
// access the corresponding element, starting from the inner dimension.
SmallVector<SmallVector<Value>>
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
Attribute layout, RankedTensorType type, bool withCTAOffset);
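As a purely illustrative picture of the structure that comment describes (the SSA value names below are invented placeholders, not Triton API), the result for a thread owning four elements of a rank-2 tensor would have this shape:

```python
# Illustration only: the [elemsPerThread x rank] layout of the emitIndices result,
# with made-up placeholder names standing in for per-dimension SSA index values.
elems_per_thread, rank = 4, 2
src_indices = [[f"%idx{e}_{d}" for d in range(rank)] for e in range(elems_per_thread)]
# => [['%idx0_0', '%idx0_1'], ['%idx1_0', '%idx1_1'], ...]
assert len(src_indices) == elems_per_thread
assert all(len(row) == rank for row in src_indices)
```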
26 changes: 26 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -869,6 +869,32 @@ def TT_HistogramOp : TT_Op<"histogram", [Pure]> {
}];
}

//
// Gather Op
//
def TT_GatherOp : TT_Op<"gather", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "local gather operation";
let description = [{
Gather elements from the input tensor using the indices tensor along a
single specified axis. The output tensor has the same shape as the indices
    tensor. The input and indices tensors must have the same number of
    dimensions, and each dimension of the indices tensor that is not the gather
dimension cannot be greater than the corresponding dimension in the input
tensor.
}];

let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$axis);
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
$src `[` $indices `]` attr-dict `:`
functional-type(operands, results)
}];

let hasVerifier = 1;
}
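To make the semantics above concrete, here is a small host-side sketch (PyTorch-based, with made-up shapes; it mirrors the torch.gather convention that the unit test added in this PR compares against). For axis = 0, out[i, j] = src[idx[i, j], j].

```python
# Illustrative sketch of the gather semantics above (not part of the PR).
import torch

src = torch.arange(12.0).reshape(4, 3)   # input tensor, shape [4, 3]
idx = torch.tensor([[3, 0, 1],           # indices tensor, shape [2, 3]
                    [2, 2, 0]])
axis = 0

out = torch.empty(idx.shape, dtype=src.dtype)
for i in range(idx.shape[0]):
    for j in range(idx.shape[1]):
        # Replace the coordinate along `axis` with the gathered index.
        out[i, j] = src[idx[i, j], j]

# Matches torch.gather, which the unit test in this PR uses as the reference.
assert torch.equal(out, torch.gather(src, axis, idx))
```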

//
// Print Op
//
4 changes: 4 additions & 0 deletions lib/Analysis/Allocation.cpp
@@ -125,6 +125,10 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
ScanLoweringHelper helper(scanOp);
return helper.getScratchSizeInBytes();
}
if (auto gatherOp = dyn_cast<GatherOp>(op)) {
GatherLoweringHelper helper(gatherOp);
return helper.getScratchSizeInBytes();
}
if (auto histogram = dyn_cast<HistogramOp>(op)) {
auto dstTy = histogram.getType();
int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
11 changes: 11 additions & 0 deletions lib/Analysis/Utility.cpp
@@ -408,6 +408,17 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
llvm_unreachable("Axis not found in order");
}

GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp)
: gatherOp(gatherOp) {}

unsigned GatherLoweringHelper::getScratchSizeInBytes() {
// For now, lower the gather op by writing the source tensor to shared memory.
// TODO(jeff): Leverage locality to avoid using scratch space when possible.
RankedTensorType srcType = gatherOp.getSrc().getType();
return product(srcType.getShape()) *
ceil<unsigned>(srcType.getElementTypeBitWidth(), 8);
}
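A quick worked example of the formula above, with an arbitrarily chosen 128x64 source tensor of 16-bit elements (numbers are illustrative, not from the PR):

```python
# Worked example of the scratch-size formula: every element of the source tensor
# is staged in shared memory, with the element size rounded up to whole bytes.
import math

src_shape = (128, 64)          # srcType.getShape()
element_bit_width = 16         # e.g. f16

num_elements = math.prod(src_shape)                  # product(srcType.getShape())
bytes_per_element = math.ceil(element_bit_width / 8)  # ceil<unsigned>(bitwidth, 8)
scratch_size_in_bytes = num_elements * bytes_per_element
assert scratch_size_in_bytes == 16384
```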

unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
if (shape.empty())
return 0;
1 change: 1 addition & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
@@ -13,6 +13,7 @@ add_triton_library(TritonGPUToLLVM
AllocateSharedMemory.cpp
ReduceOpToLLVM.cpp
ScanOpToLLVM.cpp
GatherOpToLLVM.cpp
ConvertLayoutOpToLLVM.cpp
ControlFlowOpToLLVM.cpp
FuncOpToLLVM.cpp
109 changes: 109 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp
@@ -0,0 +1,109 @@
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using namespace mlir;
using namespace mlir::triton;

namespace {
class GatherOpConversion : public ConvertOpToLLVMPattern<GatherOp> {
public:
GatherOpConversion(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo, PatternBenefit benefit)
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

LogicalResult
matchAndRewrite(GatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;

private:
const TargetInfoBase &targetInfo;
};

LogicalResult
GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
RankedTensorType srcType = op.getSrc().getType();

// Compute the src subtensor shape owned by this CTA.
SmallVector<unsigned> srcShapePerCTA =
convertType<unsigned>(triton::gpu::getShapePerCTA(srcType));

// Grab the src values in this thread.
SmallVector<Value> srcValues =
unpackLLElements(loc, adaptor.getSrc(), rewriter);

// Emit the indices of the src values owned by this thread.
SmallVector<SmallVector<Value>> srcIndices =
emitIndices(loc, rewriter, targetInfo, srcType.getEncoding(),
op.getSrc().getType(), /*withCTAOffset=*/true);

  // Store the src values owned by the thread into their respective locations in
// the scratch memory.
assert(srcValues.size() == srcIndices.size());

// Get the base pointer to the scratch memory.
Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op);

// For each src element owned by the thread, index into the scratch memory and
// then store it.
Type elemType = getTypeConverter()->convertType(srcType.getElementType());
for (auto [value, indices] : llvm::zip(srcValues, srcIndices)) {
// Convert the index at each dim into a single offset given the shape of the
// tensor.
Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA);
    // Compute the pointer into shared memory from the offset and store the value.
Value ptr = gep(smemBase.getType(), elemType, smemBase, offset);
store(value, ptr);
[Review thread on this line]

Contributor: Can be left as a TODO, but we should do a masked store with getRedundantDataMask:

    Value getRedundantDataMask(ModuleOp moduleOp, Type valueTy,

(TBH I think there are quite a few places where we need to clean up redundant operations.)

Collaborator: I don't think a masked store is better; in general there isn't a performance penalty for storing multiple times to the same address in shared memory. Using a store without a mask allows better code generation in general.

Contributor: > in general there isn't a performance penalty for storing multiple times to the same address in shared memory.

Would I be right to say this is only true if there are no bank conflicts?

Collaborator: No, as far as I know it is orthogonal to bank conflicts.

Contributor: I'm confused. Are you saying that @p st.shared where p is false on all threads in the warp can still take multiple cycles?

Collaborator: Ah no, that's not what I meant. I meant that when doing a st.shared with duplicated addresses there isn't a penalty for storing multiple times; the HW will pick one of them and store only once.

Contributor: Yes, I agree that at the single-instruction level redundancy is okay. I guess that could be expressed by ignoring redundancy within a warp.

The problematic cases are:

  • within a thread, i.e. calling st.shared more times than necessary;
  • between warps, i.e. multiple warps calling st.shared to transfer the same data.

Collaborator: Ah, fair point. I do wonder if it is always better to predicate, as redundant data should not be the common case and there are some downsides to using predicates. The main downside is that it prevents the backend from using a larger element bitwidth. Might be worth measuring at some point.

Contributor: Fair enough. I'll note that for thread-level redundancy we shouldn't actually need a masked store, though; we could just not emit the instructions.

Collaborator: True. Also, another problem I remember running into with predicates (although not this exact case) is that they tend to throw off ptxas a lot in terms of scheduling and live ranges. This applies more to loads than to stores (the liveness problem doesn't exist with loads; the scheduling I'm not sure), but it's one thing to keep in mind.
}

// Synchronize the whole CTA.
// TODO(jeff): Should we teach Membar that gather synchronizes?
barrier();

// Grab the index values owned by this thread.
SmallVector<Value> idxValues =
unpackLLElements(loc, adaptor.getIndices(), rewriter);

// I = LL(pid)
// idx = indices[I]
// I_gather = [I[d] if d != axis else idx for d in range(len(I))]
// out[I] = src[I_gather]
RankedTensorType dstType = op.getType();
SmallVector<SmallVector<Value>> dstIndices =
emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType,
/*withCTAOffset=*/true);

unsigned idxWidth = op.getIndices().getType().getElementTypeBitWidth();
unsigned axis = op.getAxis();
SmallVector<Value> results(dstIndices.size());
for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) {
// The LL index computations are performed with 32 bit integers. If the
// indices are something else, cast them to i32.
if (idxWidth > 32) {
idx = trunc(i32_ty, idx);
} else if (idxWidth < 32) {
// Negative indices don't make sense, so zero-extend.
idx = zext(i32_ty, idx);
}
indices[axis] = idx;
Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA);
Value ptr = gep(smemBase.getType(), elemType, smemBase, offset);
results[i] = load(elemType, ptr);
}

Value packed =
packLLElements(loc, getTypeConverter(), results, rewriter, dstType);
rewriter.replaceOp(op, packed);
return success();
}

} // namespace

void triton::populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit) {
patterns.insert<GatherOpConversion>(typeConverter, targetInfo, benefit);
}
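For readers following the conversion above, the NumPy sketch below models the same strategy on the host (illustrative only, not the PR's code): write every source element to a flat scratch buffer at its linearized offset, synchronize, then have each output element read back from the coordinate whose gather-axis component is replaced by the index value.

```python
# A rough host-side model of the naive shared-memory lowering (illustration only).
import numpy as np

def naive_gather_via_scratch(src: np.ndarray, indices: np.ndarray, axis: int) -> np.ndarray:
    shape = src.shape
    # Store phase: each source element goes to scratch at its linearized offset
    # (the role LLVM::linearize plays in the conversion above).
    scratch = np.empty(src.size, dtype=src.dtype)
    for coord in np.ndindex(*shape):
        scratch[np.ravel_multi_index(coord, shape)] = src[coord]

    # On the GPU, a CTA-wide barrier separates the store and load phases here.

    # Load phase: replace the coordinate along `axis` with the gathered index,
    # linearize it, and read the element back from scratch.
    out = np.empty(indices.shape, dtype=src.dtype)
    for coord in np.ndindex(*indices.shape):
        gather_coord = list(coord)
        gather_coord[axis] = int(indices[coord])
        out[coord] = scratch[np.ravel_multi_index(tuple(gather_coord), shape)]
    return out
```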
@@ -537,6 +537,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::MakeRangeOp>, TritonExpandDimsPattern,
TritonTransPattern, TritonDotPattern, GenericOpPattern<triton::LoadOp>,
GenericOpPattern<triton::StoreOp>, GenericOpPattern<triton::HistogramOp>,
GenericOpPattern<triton::GatherOp>,
GenericOpPattern<triton::ExternElementwiseOp>,
GenericOpPattern<triton::PrintOp>, GenericOpPattern<triton::AssertOp>,
GenericOpPattern<triton::AtomicCASOp>,
49 changes: 49 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -1073,6 +1073,55 @@ Speculation::Speculatability ExternElementwiseOp::getSpeculatability() {
return Speculation::NotSpeculatable;
}

// -- GatherOp --
LogicalResult GatherOp::verify() {
RankedTensorType indicesTy = getIndices().getType();
RankedTensorType srcTy = getSrc().getType();
RankedTensorType resTy = getResult().getType();

if (indicesTy.getShape() != resTy.getShape()) {
return emitOpError("indices and output shapes must match");
}
if (indicesTy.getEncoding() != resTy.getEncoding()) {
return emitOpError("indices and output encodings must match");
}
if (srcTy.getElementType() != resTy.getElementType()) {
return emitOpError("input and output element types must match");
}
if (srcTy.getRank() != indicesTy.getRank()) {
return emitOpError("input and indices ranks must match");
}
if (getAxis() >= srcTy.getRank()) {
return emitOpError("gather dimension must be less than the input rank");
}
for (int dim = 0; dim < indicesTy.getRank(); ++dim) {
if (dim == getAxis())
continue;
if (indicesTy.getShape()[dim] > srcTy.getShape()[dim]) {
return emitOpError("indices dimension ")
<< dim
<< " cannot be greater than the corresponding input dimension";
}
}

return success();
}

LogicalResult GatherOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
GatherOpAdaptor adaptor(operands, attributes, properties, regions);
auto indicesType = cast<RankedTensorType>(adaptor.getIndices().getType());
auto srcType = cast<RankedTensorType>(adaptor.getSrc().getType());

// Shape and encoding of the indices with the element type of the src.
inferredReturnTypes.push_back(
RankedTensorType::get(indicesType.getShape(), srcType.getElementType(),
indicesType.getEncoding()));
return success();
}

// -- ExperimentalTensormapCreateOp --
LogicalResult ExperimentalTensormapCreateOp::verify() {
auto rank = getBoxDim().size();
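The inferred return type can be summarized with made-up shapes (a sketch, not code from the PR): the result borrows its shape and encoding from the indices tensor and its element type from the source tensor.

```python
# Sketch of the return-type rule above with illustrative shapes and dtypes.
src_shape, src_dtype = (128, 64), "f32"
indices_shape, indices_dtype = (256, 64), "i32"

result_shape, result_dtype = indices_shape, src_dtype   # shape/encoding from indices, dtype from src
assert (result_shape, result_dtype) == ((256, 64), "f32")
```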
@@ -282,6 +282,7 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
setEncoding(user->getResults(), info, changed, user);
continue;
}
// TODO(jeff): Propagate tt.gather indices layout to dst.
}
return changed;
}
@@ -709,6 +710,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
}
return newOp;
}
// TODO(jeff): Handle tt.gather once it supports layout propagation.
llvm::report_fatal_error("unexpected op in rewrite");
return nullptr;
}
3 changes: 3 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -472,6 +472,8 @@ std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding) {
return inferSrcEncoding(trans, encoding);
if (auto reshape = dyn_cast<triton::ReshapeOp>(op))
return inferSrcEncoding(reshape, encoding);
// TODO(jeff): Handle propagating tt.gather indices -> dst layout.
// This requires updating the API to specify the exact operands and results.

return std::nullopt;
}
@@ -499,6 +501,7 @@ std::optional<Attribute> inferDstEncoding(Operation *op, Attribute encoding) {
return inferDstEncoding(trans, encoding);
if (auto reshape = dyn_cast<triton::ReshapeOp>(op))
return inferDstEncoding(reshape, encoding);
// TODO(jeff): Handle propagating tt.gather indices -> dst layout.

return std::nullopt;
}
3 changes: 3 additions & 0 deletions python/src/ir.cc
@@ -1625,6 +1625,9 @@ void init_triton_ir(py::module &&m) {
IntegerType::get(operand.getContext(), 32)),
operand);
})
.def("create_gather",
[](TritonOpBuilder &self, Value src, Value indices, int axis)
-> Value { return self.create<GatherOp>(src, indices, axis); })
// Force GPU barrier
.def("create_barrier",
[](TritonOpBuilder &self) { self.create<mlir::gpu::BarrierOp>(); })
40 changes: 40 additions & 0 deletions python/test/unit/language/test_core.py
@@ -6087,3 +6087,43 @@ def kernel(In, Out, #
perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2])

assert torch.all(ref == result)


@pytest.mark.parametrize("src_shape, indices_shape, axis", [
([4, 4], [8, 2], 0),
([128, 64], [256, 32], 0),
([128, 64], [128, 128], 1),
])
def test_gather(src_shape, indices_shape, axis):

@triton.jit
def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr,
src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr,
idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr,
out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr,
out_stride1: tl.constexpr):
src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1)
src = tl.load(src_ptr + src_offs)

idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1)
idx = tl.load(idx_ptr + idx_offs)

out = tl.gather(src, idx, axis)

out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1)
tl.store(out_ptr + out_offs, out)

def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)

gather_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1],
src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0),
indices.stride(1), output.shape[0], output.shape[1], output.stride(0), output.stride(1))

return output

src = torch.randn(src_shape, device='cuda')
indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda')
ref = torch.gather(src, axis, indices)
result = triton_gather(src, axis, indices)
assert torch.all(ref == result)
2 changes: 2 additions & 0 deletions python/triton/language/__init__.py
@@ -70,6 +70,7 @@
float8e5b16,
full,
function_type,
gather,
histogram,
inline_asm_elementwise,
int1,
@@ -188,6 +189,7 @@
"fma",
"full",
"function_type",
"gather",
"histogram",
"inline_asm_elementwise",
"interleave",