Make bitcast work for tensors of pointers
parsifal-47 committed Nov 29, 2024
1 parent 3fe82cb commit eaaf7c5
Showing 6 changed files with 109 additions and 3 deletions.
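
The change teaches the triton-to-structured pipeline to handle tt.bitcast when the operand is a tensor of pointers rather than a tensor of scalars. The shape of IR in question, copied from the new lit test at the end of this diff, is:

  %11 = tt.bitcast %10 : tensor<128x!tt.ptr<i1>> -> tensor<128x!tt.ptr<i8>>

Because arith.bitcast cannot operate on pointer element types, the commit adds a dedicated tts.cast_tptr op, rewrites pointer-tensor bitcasts to it during pointer analysis, and lowers it to an unrealized conversion cast on the memref side.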
2 changes: 2 additions & 0 deletions include/triton-shared/AnalysisStructured/PtrAnalysis.h
@@ -242,6 +242,8 @@ class PtrAnalysis {
// PtrState for knownPtrs.
LogicalResult rewriteAddptrOp(triton::AddPtrOp op);

LogicalResult rewriteBitcastOp(triton::BitcastOp op);

LogicalResult rewriteMakeTensorPtrOp(triton::MakeTensorPtrOp op);

LogicalResult rewriteAdvanceOp(triton::AdvanceOp op);
@@ -34,7 +34,7 @@ def TTS_MakeTensorPtrOp
: TTS_Op<"make_tptr", [AttrSizedOperandSegments, Pure]> {
let summary = "create a pointer that points to a tensor in memory";

// base: Base pointer used to contruct the tensor of pointers or pointer to tensor.
// base: Base pointer used to construct the tensor of pointers or pointer to tensor.
// sizes: Size of the data being loaded or stored.
// strides: The strides of the parent tensor, which means how much to increase the pointer
// by when moving by 1 element in a specific axis.
@@ -120,6 +120,19 @@ def TTS_MakeTensorPtrOp
//let hasCanonicalizer = 1;
}

def TTS_CastTensorPtrOp : TTS_Op<"cast_tptr", [SameOperandsAndResultShape,
                                               Pure]> {
  let summary = "Cast between tensor pointer types";

  let arguments = (ins TT_PtrLike:$src);

  let results = (outs TT_PtrLike:$result);

  let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";

  // TODO: Add verifier
}

// SameVariadicResultSize
// AttrSizedResultSegments
def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSegments, Pure]> {
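Per the declared assembly format, a use of the new op prints the way the lit test below checks for it:

  %14 = tts.cast_tptr %12 : tensor<128x!tt.ptr<i1>> -> tensor<128x!tt.ptr<i8>>

SameOperandsAndResultShape ties the operand and result shapes together while leaving the pointee types free to differ; per the TODO, the op does not yet have a verifier.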
27 changes: 27 additions & 0 deletions lib/AnalysisStructured/PtrAnalysis.cpp
@@ -804,6 +804,26 @@ LogicalResult PtrAnalysis::rewriteAddptrOp(triton::AddPtrOp op) {
return success();
}

LogicalResult PtrAnalysis::rewriteBitcastOp(triton::BitcastOp op) {
  Type resultType = op.getType();

  if (auto resultTensorType = dyn_cast<RankedTensorType>(resultType)) {
    Type elementType = resultTensorType.getElementType();
    if (auto pointerType = dyn_cast<triton::PointerType>(elementType)) {
      // arith::bitcast cannot handle pointers,
      // so this case needs separate handling.
      OpBuilder builder(op);
      auto cast = builder.create<mlir::tts::CastTensorPtrOp>(
          op.getLoc(), resultType, ptrMap.lookupOrNull(op.getSrc()));
      op->replaceAllUsesWith(cast);
      op->erase();
      ptrMap.map(cast.getResult(), cast.getResult());
      return success();
    }
  }

  return failure();
}

LogicalResult PtrAnalysis::rewriteMakeTensorPtrOp(triton::MakeTensorPtrOp op) {
OpBuilder builder(op);

@@ -1297,6 +1317,13 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) {
}
return WalkResult::advance();
})
.Case<triton::BitcastOp>([&](auto bitcast) {
  if (rewriteBitcastOp(bitcast).failed()) {
    // failure means incompatible arguments (result is not a tensor of pointers)
    WalkResult::skip();
  }
  return WalkResult::advance();
})
.Case<triton::MakeTensorPtrOp>([&](auto maketptr) {
if (rewriteMakeTensorPtrOp(maketptr).failed()) {
maketptr->emitRemark(
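Taken together, the two hunks above rewrite a pointer-tensor bitcast into the new structured op and wire the rewrite into PtrAnalysis::rewriteOp's dispatch. Using the value names from the lit test, the effect on the IR is roughly:

  // before --triton-to-structured
  %11 = tt.bitcast %10 : tensor<128x!tt.ptr<i1>> -> tensor<128x!tt.ptr<i8>>

  // after: the operand is the tts.make_tptr that PtrAnalysis already produced for the addptr chain
  %14 = tts.cast_tptr %12 : tensor<128x!tt.ptr<i1>> -> tensor<128x!tt.ptr<i8>>

When the result is not a tensor of pointers, rewriteBitcastOp returns failure, which the dispatch treats as incompatible arguments and leaves the op untouched rather than emitting a remark.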
23 changes: 22 additions & 1 deletion lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
@@ -848,12 +848,33 @@ struct UnrealizedCastConverter
}
};

struct CastTensorPtrConverter
    : public OpConversionPattern<tts::CastTensorPtrOp> {
  using OpConversionPattern<tts::CastTensorPtrOp>::OpConversionPattern;
  CastTensorPtrConverter(TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern<tts::CastTensorPtrOp>(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(tts::CastTensorPtrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = op.getSrc();
    auto srcType = cast<RankedTensorType>(src.getType());
    auto dstType = cast<RankedTensorType>(op.getType());
    auto pointerType = cast<triton::PointerType>(dstType.getElementType());
    auto newType =
        MemRefType::get(srcType.getShape(), pointerType.getPointeeType());

    auto unrealizedCast = rewriter.create<UnrealizedConversionCastOp>(
        op.getLoc(), newType, adaptor.getOperands());
    rewriter.replaceOp(op, unrealizedCast);
    return success();
  }
};

} // namespace

void mlir::triton::populateStructuredToMemrefConversionPatterns(
RewritePatternSet &patterns, TypeConverter &typeConverter) {
patterns.add<UnrealizedCastConverter>(typeConverter, patterns.getContext());
patterns.add<MakeTensorPtrConverter, LoadConverter, StoreConverter,
ScalarLoadConverter, ScalarStoreConverter>(
ScalarLoadConverter, ScalarStoreConverter, CastTensorPtrConverter>(
patterns.getContext());
}
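
The new CastTensorPtrConverter lowers tts.cast_tptr during structured-to-memref conversion by building a memref type from the source shape and the destination pointee type, then wrapping the converted operand in an unrealized conversion cast. Assuming the pass's type converter maps tensor<128x!tt.ptr<i1>> to something like memref<128xi1> (an assumption for illustration, not something visible in this diff), the lowering would look roughly like:

  %c = builtin.unrealized_conversion_cast %src : memref<128xi1> to memref<128xi8>

The result type carries the destination element type (i8) and the source shape, matching the newType computed in matchAndRewrite.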
1 change: 0 additions & 1 deletion python/examples/conftest.py
@@ -32,7 +32,6 @@ def device(request):
"test_tensor_atomic_rmw_block",
"test_nested_if_else_return",
"test_ptx_cast",
"test_compare_op",
"test_maxnreg",
"test_join",
"test_join_scalars",
44 changes: 44 additions & 0 deletions test/Conversion/TritonToStructured/cast_ptr.mlir
@@ -0,0 +1,44 @@
// RUN: triton-shared-opt --triton-to-structured %s | FileCheck %s

module {
tt.func public @kernel(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%1 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x!tt.ptr<i8>>
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<i8>>, tensor<128xi32>
%3 = tt.load %2 : tensor<128x!tt.ptr<i8>>
%4 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<128x!tt.ptr<i32>>
%5 = tt.addptr %4, %0 : tensor<128x!tt.ptr<i32>>, tensor<128xi32>
%6 = tt.load %5 : tensor<128x!tt.ptr<i32>>
%7 = arith.extsi %3 : tensor<128xi8> to tensor<128xi32>
%8 = arith.cmpi eq, %7, %6 : tensor<128xi32>
%9 = tt.splat %arg0 : !tt.ptr<i1> -> tensor<128x!tt.ptr<i1>>
%10 = tt.addptr %9, %0 : tensor<128x!tt.ptr<i1>>, tensor<128xi32>
%11 = tt.bitcast %10 : tensor<128x!tt.ptr<i1>> -> tensor<128x!tt.ptr<i8>>
%12 = arith.extui %8 : tensor<128xi1> to tensor<128xi8>
tt.store %11, %12 : tensor<128x!tt.ptr<i8>>
tt.return
}
}

// CHECK: module {
// CHECK: tt.func public @kernel(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
// CHECK: [[VAR_0:%.+]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK: [[VAR_1:%.+]] = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x!tt.ptr<i8>>
// CHECK: [[VAR_2:%.+]] = tts.make_tptr %arg1 to sizes: [128], strides: [1], offsets: [0], shape: [0], order: [] : <i8> to tensor<128x!tt.ptr<i8>>
// CHECK: [[VAR_3:%.+]] = tt.addptr [[VAR_1]], [[VAR_0]] : tensor<128x!tt.ptr<i8>>, tensor<128xi32>
// CHECK: [[VAR_4:%.+]] = "tts.load"([[VAR_2]]) <{operandSegmentSizes = array<i32: 1, 0, 0>, static_mask_dims = array<i64>}> : (tensor<128x!tt.ptr<i8>>) -> tensor<128xi8>
// CHECK: [[VAR_5:%.+]] = tt.splat %arg2 : !tt.ptr<i32> -> tensor<128x!tt.ptr<i32>>
// CHECK: [[VAR_6:%.+]] = tts.make_tptr %arg2 to sizes: [128], strides: [1], offsets: [0], shape: [0], order: [] : <i32> to tensor<128x!tt.ptr<i32>>
// CHECK: [[VAR_7:%.+]] = tt.addptr [[VAR_5]], [[VAR_0]] : tensor<128x!tt.ptr<i32>>, tensor<128xi32>
// CHECK: [[VAR_8:%.+]] = "tts.load"([[VAR_6]]) <{operandSegmentSizes = array<i32: 1, 0, 0>, static_mask_dims = array<i64>}> : (tensor<128x!tt.ptr<i32>>) -> tensor<128xi32>
// CHECK: [[VAR_9:%.+]] = arith.extsi [[VAR_4]] : tensor<128xi8> to tensor<128xi32>
// CHECK: [[VAR_10:%.+]] = arith.cmpi eq, [[VAR_9]], [[VAR_8]] : tensor<128xi32>
// CHECK: [[VAR_11:%.+]] = tt.splat %arg0 : !tt.ptr<i1> -> tensor<128x!tt.ptr<i1>>
// CHECK: [[VAR_12:%.+]] = tts.make_tptr %arg0 to sizes: [128], strides: [1], offsets: [0], shape: [0], order: [] : <i1> to tensor<128x!tt.ptr<i1>>
// CHECK: [[VAR_13:%.+]] = tt.addptr [[VAR_11]], [[VAR_0]] : tensor<128x!tt.ptr<i1>>, tensor<128xi32>
// CHECK: [[VAR_14:%.+]] = tts.cast_tptr [[VAR_12]] : tensor<128x!tt.ptr<i1>> -> tensor<128x!tt.ptr<i8>>
// CHECK: [[VAR_15:%.+]] = arith.extui [[VAR_10]] : tensor<128xi1> to tensor<128xi8>
// CHECK: "tts.store"([[VAR_14]], [[VAR_15]]) <{static_mask_dims = array<i64>}> : (tensor<128x!tt.ptr<i8>>, tensor<128xi8>) -> ()
// CHECK: tt.return
// CHECK: }
// CHECK: }
