diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index 85cdfbe38c..f4d509235b 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -1602,6 +1602,34 @@ class OpFusionHelper { // Find the fusible op chain from the root op void findFusibleOps(); + // Decide how many loops should be collapsed with the outermost loop + // The purpose of loop collapsing is to increase parallelism without + // using nested parallel loops. + // Though loops can always be collapsed into one, it may introduce + // overhead for certain reference pattern. For elementwise operation, + // extra computation is needed for the broadcasted dim if the corresponding + // loop is collapsed. + // The current decision is based on broadcast pattern of all operands in the + // fused ops. + // collapsibleLoops = k means the top k+1 loops can be collapsed into one + // loop. + // k == -1 means that the index needs special handling even if no loops are + // collapsed. There are broadcast in the first dimension. + // ToFix: should -1 be allowed/needed? + int collapsibleLoops; + void decideCollapsibleLoops(); + + // Detect which dims of operand is broadcasted to output. + // It is assumed that shape of operand is within the output. + // In the return, the one in the ith bit means that there is a broadcast for + // the ith dim of the output + // ToFix: move this function into DimAnalysis + uint64_t broadcastDimsForOneOperand(Value operand, Value output); + + // Collect (OR) all the broadcast dims for all the operands for the rootOp + // and the fusible ops + uint64_t broadcastDimsForAll(); + bool isFusibleListEmpty() { return fusibleOps.size() == 0; } // The final output type after fusion. @@ -1827,6 +1855,84 @@ void OpFusionHelper::findFusibleOps() { }); } +uint64_t OpFusionHelper::broadcastDimsForAll() { + + Value output = rootOp->getResults()[0]; + // Type outputType = rootOp->getResultTypes()[0]; + // uint64_t rank = getRank(outputType); + Operation *op = rootOp; + uint64_t result = 0; + + size_t fusibleOpIndex = 0; + do { + // check the operands of this op. + for (Value operand : op->getOperands()) { + // Check the MDBroadcast op + if (!isa(op)) + continue; + // ToFix: some our elementwise op may have issue for collapsing. + // For example, the other operands of ONNXClipOp + result |= broadcastDimsForOneOperand(operand, output); + } + } while ((fusibleOpIndex < fusibleOps.size()) && + (op = fusibleOps[fusibleOpIndex++])); + return result; +} + +uint64_t OpFusionHelper::broadcastDimsForOneOperand( + Value operand, Value output) { + uint64_t result = 0; + uint64_t outputRank = getRank(output.getType()); + uint64_t operandRank = getRank(operand.getType()); + if (operandRank == 1) { + ArrayRef operandShape = getShape(operand.getType()); + if (operandShape[0] == 1) + // Special case for one element: the index will always be [0] + return result; + } + + int64_t diff = outputRank - operandRank; + for (int64_t i = 0; i < (int64_t)outputRank; i++) { + if (i < diff) { + result |= (0x01) << i; + } else if (!dimAnalysis->sameDim(operand, i - diff, output, i)) { + result |= (0x01) << i; + } + } + return result; +} + +// Current implementation only checks whether the dimension is broadcasted. +// Loop collapsing stops at the first dimension that has broadcast +// Could be further improved if the benenfit of collapsing is higher than +// the overhead of index reconstruction. +void OpFusionHelper::decideCollapsibleLoops() { + uint64_t broadcastBits = broadcastDimsForAll(); + Type outputType = rootOp->getResultTypes()[0]; + int rank = (int)getRank(outputType); + + // LLVM_DEBUG(llvm::dbgs() << "broadcast bits " << broadcastBits << "\n";); + // llvm::dbgs() << "\nFusion " << fusibleOps.size() << "\n"; + //(llvm::dbgs() << "broadcast bits " << broadcastBits << "\n"); + // llvm::dbgs() << "number of loops " << rank << "\n"; + + for (int i = 0; i < rank; i++) { + if ((broadcastBits & (0x01 << i)) != 0) { + collapsibleLoops = i - 1; + ///*LLVM_DEBUG*/ ( + // llvm::dbgs() << "loop collaping: " << collapsibleLoops << "\n"); + return; + } + } + + collapsibleLoops = rank - 1; + ///*LLVM_DEBUG*/ ( + // llvm::dbgs() << "loop collaping: " << collapsibleLoops << "\n"); +} + // After fusion, the only store is for the last Op. // Therefore, the allocation should be the output of the last Op MemRefType OpFusionHelper::getOutputType(MemRefType outputType) { @@ -2028,6 +2134,7 @@ struct ONNXElementwiseUnaryOpLowering // Try to fuse the unary elementwise consumers OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis); opFusionHelper.findFusibleOps(); + opFusionHelper.decideCollapsibleLoops(); outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); // Insert an allocation for the result of this operation. @@ -2196,6 +2303,7 @@ struct ONNXElementwiseBinaryOpLowering // Try to fuse the unary elementwise consumers OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis); opFusionHelper.findFusibleOps(); + opFusionHelper.decideCollapsibleLoops(); outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); // Insert an allocation and deallocation for the result of this operation. @@ -2359,6 +2467,7 @@ struct ONNXElementwiseVariadicOpLowering // Try to fuse the unary elementwise consumers OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis); opFusionHelper.findFusibleOps(); + opFusionHelper.decideCollapsibleLoops(); outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); // Insert an allocation and deallocation for the result of this operation.