Skip to content

Commit

Permalink
Implementing lowering for calls (#164)
Browse files Browse the repository at this point in the history
+ Also removes the assertion that disallowed 1x1 tensors. I do not see
why it was set, and removing it does not appear to produce an error. Let
me know what you think. Thank you!

---------

Co-authored-by: Renat Idrisov <parsifal-47@users.noreply.github.com>
  • Loading branch information
parsifal-47 and parsifal-47 authored Nov 29, 2024
1 parent 3fe82cb commit 6aa82f1
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,49 @@ struct BitcastConverter : public OpConversionPattern<triton::BitcastOp> {
}
};

/// Rewrites a `triton::CallOp` into a `func::CallOp` with the same callee
/// and result types.
///
/// When the callee (looked up by name in the enclosing module) declares more
/// inputs than the call supplies, the missing operands are taken from the
/// trailing entry-block arguments of the function that contains the call.
struct CallConverter : public OpConversionPattern<triton::CallOp> {
  using OpConversionPattern<triton::CallOp>::OpConversionPattern;

  /// Replaces `op` with an equivalent `func::CallOp`.
  ///
  /// Returns failure (after emitting an error on `op`) if the callee needs
  /// more trailing arguments than the enclosing function has, or if the
  /// replacement call could not be created; returns success otherwise.
  LogicalResult
  matchAndRewrite(triton::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> args = adaptor.getOperands();

    // We need to pass extra arguments added by addProgramInfo which are
    // num_programs and program_ids.
    if (FuncOp parentFunc = op->getParentOfType<triton::FuncOp>()) {
      SymbolRefAttr calleeAttr = op.getCalleeAttr();
      StringRef calleeName = calleeAttr.getRootReference();

      if (ModuleOp module = op->getParentOfType<ModuleOp>()) {
        if (FuncOp calleeFunc = module.lookupSymbol<FuncOp>(calleeName)) {
          size_t argsNeed = calleeFunc.getFunctionType().getInputs().size();

          if (argsNeed > args.size()) {
            size_t missing = argsNeed - args.size();
            Block &entryBlock = parentFunc.front();
            auto parentInputs = entryBlock.getArguments();
            size_t argsParent = parentInputs.size();

            // Guard the index arithmetic below: without this check
            // `argsParent - i - 1` would wrap around and index out of bounds.
            if (missing > argsParent) {
              op.emitError("Caller does not have enough arguments to forward "
                           "to the callee");
              return failure();
            }

            // NOTE(review): the arguments are appended starting from the
            // parent's LAST argument, i.e. in reverse order relative to the
            // parent's signature. This order is kept unchanged here — confirm
            // it matches the callee's parameter order.
            for (size_t i = 0; i < missing; i++) {
              args.push_back(parentInputs[argsParent - i - 1]);
            }
          }
        }
      }
    }

    auto call = rewriter.create<func::CallOp>(
        op.getLoc(), op.getCallee(), op.getResultTypes(), args);

    if (!call) {
      op.emitError("Failed to create func::CallOp");
      return failure();
    }

    rewriter.replaceOp(op, call);
    return success();
  }
};

struct FpToFpConverter : public OpConversionPattern<triton::FpToFpOp> {
using OpConversionPattern<triton::FpToFpOp>::OpConversionPattern;

Expand Down
9 changes: 8 additions & 1 deletion lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,21 @@ class StructuredToMemrefPass
ConversionTarget target(getContext());
TritonFunctionSignatureConverter typeConverter;

// Update function signature to use memrefs
// Update function signatures and calls to use memrefs
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isSignatureLegal(op.getFunctionType());
});

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);

target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
return typeConverter.isLegal(op.getResultTypes()) && typeConverter.isLegal(op.getOperandTypes());
});

populateFunctionOpInterfaceTypeConversionPattern<func::CallOp>(
patterns, typeConverter);

return applyPartialConversion(moduleOp, target, std::move(patterns));
}

Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns(
patterns.add<MakeRangeConverter>(patterns.getContext());
patterns.add<ExpandDimsConverter>(patterns.getContext());
patterns.add<BitcastConverter>(patterns.getContext());
patterns.add<CallConverter>(patterns.getContext());
patterns.add<MulHiUIOpConverter>(patterns.getContext());
patterns.add<PreciseSqrtConverter>(patterns.getContext());
patterns.add<PreciseDivConverter>(patterns.getContext());
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ void mlir::triton::populateTritonToLinalgConversionPatterns(
unsigned int launchGridRank) {
populateFunctionOpInterfaceTypeConversionPattern<triton::FuncOp>(
patterns, typeConverter);
populateFunctionOpInterfaceTypeConversionPattern<triton::CallOp>(
patterns, typeConverter);

patterns.add<MetaOpConverter>(patterns.getContext());
patterns.add<StoreConverter>(patterns.getContext());
Expand All @@ -49,6 +51,7 @@ void mlir::triton::populateTritonToLinalgConversionPatterns(
patterns.add<MakeRangeConverter>(patterns.getContext());
patterns.add<ExpandDimsConverter>(patterns.getContext());
patterns.add<BitcastConverter>(patterns.getContext());
patterns.add<CallConverter>(patterns.getContext());
patterns.add<MulHiUIOpConverter>(patterns.getContext());
patterns.add<PreciseSqrtConverter>(patterns.getContext());
patterns.add<PreciseDivConverter>(patterns.getContext());
Expand Down
23 changes: 23 additions & 0 deletions test/Conversion/TritonToLinalg/call.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s

// Input module: a zero-argument callee and a caller that invokes it via
// tt.call, used to exercise the lowering of calls.
module {
// Callee: takes no arguments and returns the f32 constant 42.0.
tt.func @_sum_combine__fp32() -> f32{
%0 = arith.constant 42.0 : f32
tt.return %0 : f32
}
// Caller: calls the function above with no operands and returns its result.
tt.func @test() -> f32{
%0 = tt.call @_sum_combine__fp32() : () -> f32
tt.return %0 : f32
}
}

// CHECK: module {
// CHECK: func.func @_sum_combine__fp32(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) -> f32 {
// CHECK: [[CST_:%.+]] = arith.constant 4.200000e+01 : f32
// CHECK: return [[CST_]] : f32
// CHECK: }
// CHECK: func.func @test(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) -> f32 {
// CHECK: [[VAR_0_:%.+]] = call @_sum_combine__fp32(%arg5, %arg4, %arg3, %arg2, %arg1, %arg0) : (i32, i32, i32, i32, i32, i32) -> f32
// CHECK: return [[VAR_0_]] : f32
// CHECK: }
// CHECK: }

0 comments on commit 6aa82f1

Please sign in to comment.