diff --git a/include/triton-shared/AnalysisStructured/PtrAnalysis.h b/include/triton-shared/AnalysisStructured/PtrAnalysis.h index 3e8b6b60..c7f0a057 100644 --- a/include/triton-shared/AnalysisStructured/PtrAnalysis.h +++ b/include/triton-shared/AnalysisStructured/PtrAnalysis.h @@ -242,6 +242,8 @@ class PtrAnalysis { // PtrState for knownPtrs. LogicalResult rewriteAddptrOp(triton::AddPtrOp op); + LogicalResult rewriteBitcastOp(triton::BitcastOp op); + LogicalResult rewriteMakeTensorPtrOp(triton::MakeTensorPtrOp op); LogicalResult rewriteAdvanceOp(triton::AdvanceOp op); diff --git a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td index c0f89bfc..224d0dd7 100644 --- a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td +++ b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td @@ -34,7 +34,7 @@ def TTS_MakeTensorPtrOp : TTS_Op<"make_tptr", [AttrSizedOperandSegments, Pure]> { let summary = "create a pointer that points to a tensor in memory"; - // base: Base pointer used to contruct the tensor of pointers or pointer to tensor. + // base: Base pointer used to construct the tensor of pointers or pointer to tensor. // sizes: Size of the data being loaded or stored. // strides: The strides of the parent tensor, which means how much to increase the pointer // by when moving by 1 element in a specific axis. 
@@ -120,6 +120,19 @@ def TTS_MakeTensorPtrOp
   //let hasCanonicalizer = 1;
 }
 
+def TTS_CastTensorPtrOp : TTS_Op<"cast_tptr", [SameOperandsAndResultShape,
+                                                Pure]> {
+  let summary = "Cast between types tensor pointers";
+
+  let arguments = (ins TT_PtrLike:$src);
+
+  let results = (outs TT_PtrLike:$result);
+
+  let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
+
+  // TODO: Add verifier
+}
+
 // SameVariadicResultSize
 // AttrSizedResultSegments
 def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSegments, Pure]> {
diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp
index bff6b1b7..85070c13 100644
--- a/lib/AnalysisStructured/PtrAnalysis.cpp
+++ b/lib/AnalysisStructured/PtrAnalysis.cpp
@@ -804,6 +804,38 @@ LogicalResult PtrAnalysis::rewriteAddptrOp(triton::AddPtrOp op) {
   return success();
 }
 
+// Rewrite a tt.bitcast whose result is a tensor of pointers into a
+// tts.cast_tptr of the already-rewritten source pointer taken from ptrMap.
+// Returns failure() without touching the IR when the result is not a tensor
+// of pointers or when ptrMap holds no value for the source.
+LogicalResult PtrAnalysis::rewriteBitcastOp(triton::BitcastOp op) {
+  Type resultType = op.getType();
+
+  auto resultTensorType = dyn_cast<RankedTensorType>(resultType);
+  if (!resultTensorType ||
+      !isa<triton::PointerType>(resultTensorType.getElementType())) {
+    return failure();
+  }
+
+  // lookupOrNull may hand back a null Value; building the cast from it would
+  // create an op with an invalid operand, so bail out instead.
+  Value src = ptrMap.lookupOrNull(op.getSrc());
+  if (!src) {
+    return failure();
+  }
+
+  // arith::bitcast cannot handle pointers,
+  // we need to handle this clause separately
+  OpBuilder builder(op);
+  auto cast =
+      builder.create<tts::CastTensorPtrOp>(op.getLoc(), resultType, src);
+  op->replaceAllUsesWith(cast);
+  op->erase();
+  // Users now reference the cast directly, so register it under itself.
+  ptrMap.map(cast.getResult(), cast.getResult());
+  return success();
+}
+
 LogicalResult PtrAnalysis::rewriteMakeTensorPtrOp(triton::MakeTensorPtrOp op) {
   OpBuilder builder(op);
 
@@ -1297,6 +1329,15 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) {
           }
           return WalkResult::advance();
         })
+        .Case<triton::BitcastOp>([&](auto bitcast) {
+          if (rewriteBitcastOp(bitcast).failed()) {
+            // failure means incompatible arguments; the op is still in place
+            return WalkResult::advance();
+          }
+          // The bitcast has been erased: skip it so the walker does not touch
+          // the dead op (needed for pre-order walks, harmless otherwise).
+          return WalkResult::skip();
+        })
         .Case<triton::MakeTensorPtrOp>([&](auto maketptr) {
           if (rewriteMakeTensorPtrOp(maketptr).failed()) {
maketptr->emitRemark(
diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
index b5e1165a..c80e4eb8 100644
--- a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
+++ b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
@@ -848,12 +848,45 @@ struct UnrealizedCastConverter
   }
 };
 
+// Lowers tts.cast_tptr to an unrealized conversion cast producing a memref
+// that has the source's shape and the result pointer's pointee element type.
+struct CastTensorPtrConverter
+    : public OpConversionPattern<tts::CastTensorPtrOp> {
+  using OpConversionPattern<tts::CastTensorPtrOp>::OpConversionPattern;
+  CastTensorPtrConverter(TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<tts::CastTensorPtrOp>(typeConverter, context) {}
+
+  LogicalResult
+  matchAndRewrite(tts::CastTensorPtrOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // NOTE(review): the op only constrains its types to TT_PtrLike; fail the
+    // match on anything that is not a tensor of pointers instead of
+    // asserting inside cast<>.
+    auto srcType = dyn_cast<RankedTensorType>(op.getSrc().getType());
+    auto dstType = dyn_cast<RankedTensorType>(op.getType());
+    if (!srcType || !dstType) {
+      return rewriter.notifyMatchFailure(op, "expected ranked tensor types");
+    }
+    auto pointerType = dyn_cast<triton::PointerType>(dstType.getElementType());
+    if (!pointerType) {
+      return rewriter.notifyMatchFailure(op, "expected pointer element type");
+    }
+
+    auto newType =
+        MemRefType::get(srcType.getShape(), pointerType.getPointeeType());
+    auto unrealizedCast = rewriter.create<UnrealizedConversionCastOp>(
+        op.getLoc(), newType, adaptor.getOperands());
+    rewriter.replaceOp(op, unrealizedCast);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::triton::populateStructuredToMemrefConversionPatterns(
     RewritePatternSet &patterns, TypeConverter &typeConverter) {
   patterns.add<UnrealizedCastConverter>(typeConverter, patterns.getContext());
   patterns.add<MakeTensorPtrConverter, LoadConverter, StoreConverter,
-               ScalarLoadConverter, ScalarStoreConverter>(
+               ScalarLoadConverter, ScalarStoreConverter, CastTensorPtrConverter>(
       patterns.getContext());
 }
diff --git a/python/examples/conftest.py b/python/examples/conftest.py
index 441f0033..0cf8acee 100644
--- a/python/examples/conftest.py
+++ b/python/examples/conftest.py
@@ -32,7 +32,6 @@ def device(request):
         "test_tensor_atomic_rmw_block",
         "test_nested_if_else_return",
         "test_ptx_cast",
-        "test_compare_op",
         "test_maxnreg",
         "test_join",
         "test_join_scalars",
diff --git a/test/Conversion/TritonToStructured/cast_ptr.mlir b/test/Conversion/TritonToStructured/cast_ptr.mlir
new file mode 100644
index 00000000..b63e14fc
--- /dev/null
+++ 
b/test/Conversion/TritonToStructured/cast_ptr.mlir
@@ -0,0 +1,44 @@
+// RUN: triton-shared-opt --triton-to-structured %s | FileCheck %s --implicit-check-not=tt.bitcast
+
+module {
+  tt.func public @kernel(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+    %1 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x!tt.ptr<i8>>
+    %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<i8>>, tensor<128xi32>
+    %3 = tt.load %2 : tensor<128x!tt.ptr<i8>>
+    %4 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<128x!tt.ptr<i32>>
+    %5 = tt.addptr %4, %0 : tensor<128x!tt.ptr<i32>>, tensor<128xi32>
+    %6 = tt.load %5 : tensor<128x!tt.ptr<i32>>
+    %7 = arith.extsi %3 : tensor<128xi8> to tensor<128xi32>
+    %8 = arith.cmpi eq, %7, %6 : tensor<128xi32>
+    %9 = tt.splat %arg0 : !tt.ptr<i1> -> tensor<128x!tt.ptr<i1>>
+    %10 = tt.addptr %9, %0 : tensor<128x!tt.ptr<i1>>, tensor<128xi32>
+    %11 = tt.bitcast %10 : tensor<128x!tt.ptr<i1>> -> tensor<128x!tt.ptr<i8>>
+    %12 = arith.extui %8 : tensor<128xi1> to tensor<128xi8>
+    tt.store %11, %12 : tensor<128x!tt.ptr<i8>>
+    tt.return
+  }
+}
+
+// CHECK: module {
+// CHECK:   tt.func public @kernel(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+// CHECK:     [[VAR_0:%.+]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+// CHECK:     [[VAR_1:%.+]] = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x!tt.ptr<i8>>
+// CHECK:     [[VAR_2:%.+]] = tts.make_tptr %arg1 to sizes: [128], strides: [1], offsets: [0], shape: [0], order: [] : <i8> to tensor<128x!tt.ptr<i8>>
+// CHECK:     [[VAR_3:%.+]] = tt.addptr [[VAR_1]], [[VAR_0]] : tensor<128x!tt.ptr<i8>>, tensor<128xi32>
+// CHECK:     [[VAR_4:%.+]] = "tts.load"([[VAR_2]]) <{operandSegmentSizes = array<i32: 1, 0, 0>, static_mask_dims = array<i64>}> : (tensor<128x!tt.ptr<i8>>) -> tensor<128xi8>
+// CHECK:     [[VAR_5:%.+]] = tt.splat %arg2 : !tt.ptr<i32> -> tensor<128x!tt.ptr<i32>>
+// CHECK:     [[VAR_6:%.+]] = tts.make_tptr %arg2 to sizes: [128], strides: [1], offsets: [0], shape: [0], order: [] : <i32> to tensor<128x!tt.ptr<i32>>
+// CHECK:     [[VAR_7:%.+]] = tt.addptr [[VAR_5]], [[VAR_0]] : tensor<128x!tt.ptr<i32>>, tensor<128xi32>
+// CHECK:     [[VAR_8:%.+]] = "tts.load"([[VAR_6]]) <{operandSegmentSizes = array<i32: 1, 0, 0>, static_mask_dims = array<i64>}> : (tensor<128x!tt.ptr<i32>>) -> tensor<128xi32>
+// CHECK:     [[VAR_9:%.+]] = arith.extsi [[VAR_4]] : tensor<128xi8> to tensor<128xi32>
+// CHECK:     [[VAR_10:%.+]] = arith.cmpi eq, [[VAR_9]], [[VAR_8]] : tensor<128xi32>
+// CHECK:     [[VAR_11:%.+]] = tt.splat %arg0 : !tt.ptr<i1> -> tensor<128x!tt.ptr<i1>>
+// CHECK:     [[VAR_12:%.+]] = tts.make_tptr %arg0 to sizes: [128], strides: [1], offsets: [0], shape: [0], order: [] : <i1> to tensor<128x!tt.ptr<i1>>
+// CHECK:     [[VAR_13:%.+]] = tt.addptr [[VAR_11]], [[VAR_0]] : tensor<128x!tt.ptr<i1>>, tensor<128xi32>
+// CHECK:     [[VAR_14:%.+]] = tts.cast_tptr [[VAR_12]] : tensor<128x!tt.ptr<i1>> -> tensor<128x!tt.ptr<i8>>
+// CHECK:     [[VAR_15:%.+]] = arith.extui [[VAR_10]] : tensor<128xi1> to tensor<128xi8>
+// CHECK:     "tts.store"([[VAR_14]], [[VAR_15]]) <{static_mask_dims = array<i64>}> : (tensor<128x!tt.ptr<i8>>, tensor<128xi8>) -> ()
+// CHECK:     tt.return
+// CHECK:   }
+// CHECK: }