Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loop collapsing with fusion #2618

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1602,6 +1602,34 @@ class OpFusionHelper {
// Find the fusible op chain from the root op
void findFusibleOps();

// Decide how many loops should be collapsed with the outermost loop
// The purpose of loop collapsing is to increase parallelism without
// using nested parallel loops.
// Though loops can always be collapsed into one, it may introduce
// overhead for certain reference pattern. For elementwise operation,
// extra computation is needed for the broadcasted dim if the corresponding
// loop is collapsed.
// The current decision is based on broadcast pattern of all operands in the
// fused ops.
// collapsibleLoops = k means the top k+1 loops can be collapsed into one
// loop.
// k == -1 means that the index needs special handling even if no loops are
// collapsed. There are broadcast in the first dimension.
// ToFix: should -1 be allowed/needed?
int collapsibleLoops;
void decideCollapsibleLoops();

// Detect which dims of operand is broadcasted to output.
// It is assumed that shape of operand is within the output.
// In the return, the one in the ith bit means that there is a broadcast for
// the ith dim of the output
// ToFix: move this function into DimAnalysis
uint64_t broadcastDimsForOneOperand(Value operand, Value output);

// Collect (OR) all the broadcast dims for all the operands for the rootOp
// and the fusible ops
uint64_t broadcastDimsForAll();

bool isFusibleListEmpty() { return fusibleOps.size() == 0; }

// The final output type after fusion.
Expand Down Expand Up @@ -1827,6 +1855,84 @@ void OpFusionHelper::findFusibleOps() {
});
}

uint64_t OpFusionHelper::broadcastDimsForAll() {

Value output = rootOp->getResults()[0];
// Type outputType = rootOp->getResultTypes()[0];
// uint64_t rank = getRank(outputType);
Operation *op = rootOp;
uint64_t result = 0;

size_t fusibleOpIndex = 0;
do {
// check the operands of this op.
for (Value operand : op->getOperands()) {
// Check the MDBroadcast op
if (!isa<ONNXAddOp, ONNXAndOp, ONNXBitwiseAndOp, ONNXBitwiseXorOp,
ONNXBitShiftOp, ONNXDivOp, ONNXEqualOp, ONNXGreaterOp, ONNXLessOp,
ONNXMaxOp, ONNXMeanOp, ONNXMinOp, ONNXModOp, ONNXMulOp, ONNXOrOp,
ONNXPowOp, ONNXSubOp, ONNXSumOp, ONNXXorOp>(op))
continue;
// ToFix: some our elementwise op may have issue for collapsing.
// For example, the other operands of ONNXClipOp
result |= broadcastDimsForOneOperand(operand, output);
}
} while ((fusibleOpIndex < fusibleOps.size()) &&
(op = fusibleOps[fusibleOpIndex++]));
return result;
}

uint64_t OpFusionHelper::broadcastDimsForOneOperand(
Value operand, Value output) {
uint64_t result = 0;
uint64_t outputRank = getRank(output.getType());
uint64_t operandRank = getRank(operand.getType());
if (operandRank == 1) {
ArrayRef<int64_t> operandShape = getShape(operand.getType());
if (operandShape[0] == 1)
// Special case for one element: the index will always be [0]
return result;
}

int64_t diff = outputRank - operandRank;
for (int64_t i = 0; i < (int64_t)outputRank; i++) {
if (i < diff) {
result |= (0x01) << i;
} else if (!dimAnalysis->sameDim(operand, i - diff, output, i)) {
result |= (0x01) << i;
}
}
return result;
}

// Current implementation only checks whether the dimension is broadcasted.
// Loop collapsing stops at the first dimension that has broadcast
// Could be further improved if the benenfit of collapsing is higher than
// the overhead of index reconstruction.
void OpFusionHelper::decideCollapsibleLoops() {
uint64_t broadcastBits = broadcastDimsForAll();
Type outputType = rootOp->getResultTypes()[0];
int rank = (int)getRank(outputType);

// LLVM_DEBUG(llvm::dbgs() << "broadcast bits " << broadcastBits << "\n";);
// llvm::dbgs() << "\nFusion " << fusibleOps.size() << "\n";
//(llvm::dbgs() << "broadcast bits " << broadcastBits << "\n");
// llvm::dbgs() << "number of loops " << rank << "\n";

for (int i = 0; i < rank; i++) {
if ((broadcastBits & (0x01 << i)) != 0) {
collapsibleLoops = i - 1;
///*LLVM_DEBUG*/ (
// llvm::dbgs() << "loop collaping: " << collapsibleLoops << "\n");
return;
}
}

collapsibleLoops = rank - 1;
///*LLVM_DEBUG*/ (
// llvm::dbgs() << "loop collaping: " << collapsibleLoops << "\n");
}

// After fusion, the only store is for the last Op.
// Therefore, the allocation should be the output of the last Op
MemRefType OpFusionHelper::getOutputType(MemRefType outputType) {
Expand Down Expand Up @@ -2028,6 +2134,7 @@ struct ONNXElementwiseUnaryOpLowering
// Try to fuse the unary elementwise consumers
OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis);
opFusionHelper.findFusibleOps();
opFusionHelper.decideCollapsibleLoops();
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

// Insert an allocation for the result of this operation.
Expand Down Expand Up @@ -2196,6 +2303,7 @@ struct ONNXElementwiseBinaryOpLowering
// Try to fuse the unary elementwise consumers
OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis);
opFusionHelper.findFusibleOps();
opFusionHelper.decideCollapsibleLoops();
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

// Insert an allocation and deallocation for the result of this operation.
Expand Down Expand Up @@ -2359,6 +2467,7 @@ struct ONNXElementwiseVariadicOpLowering
// Try to fuse the unary elementwise consumers
OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis);
opFusionHelper.findFusibleOps();
opFusionHelper.decideCollapsibleLoops();
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

// Insert an allocation and deallocation for the result of this operation.
Expand Down
Loading