From d2869ce15bcd09ce652ab8092947cfbc971456ea Mon Sep 17 00:00:00 2001 From: jinchen62 Date: Wed, 13 Nov 2024 12:54:31 -0800 Subject: [PATCH] Switch to torch decomposition from linalg lowering --- lib/Conversion/TorchToLinalg/CMakeLists.txt | 1 - .../TorchToLinalg/PopulatePatterns.h | 3 - .../TorchToLinalg/TorchToLinalg.cpp | 2 - .../TorchToLinalg/TorchvisionOp.cpp | 178 ----------- .../Torch/Transforms/DecomposeComplexOps.cpp | 287 +++++++++++++++++- 5 files changed, 286 insertions(+), 185 deletions(-) delete mode 100644 lib/Conversion/TorchToLinalg/TorchvisionOp.cpp diff --git a/lib/Conversion/TorchToLinalg/CMakeLists.txt b/lib/Conversion/TorchToLinalg/CMakeLists.txt index ad9b5cfe30c0..fbf66556eadc 100644 --- a/lib/Conversion/TorchToLinalg/CMakeLists.txt +++ b/lib/Conversion/TorchToLinalg/CMakeLists.txt @@ -8,7 +8,6 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg TensorConstructors.cpp TensorScalarInterop.cpp TorchToLinalg.cpp - TorchvisionOp.cpp Uncategorized.cpp Utils.cpp diff --git a/lib/Conversion/TorchToLinalg/PopulatePatterns.h b/lib/Conversion/TorchToLinalg/PopulatePatterns.h index 8f798220327c..56691c82c7c1 100644 --- a/lib/Conversion/TorchToLinalg/PopulatePatterns.h +++ b/lib/Conversion/TorchToLinalg/PopulatePatterns.h @@ -63,9 +63,6 @@ void populateIndirectDataMovementPatternsAndLegality( void populateTensorConstructorsPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target); -void populateTorchvisionPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target); } // namespace torch_to_linalg } // namespace torch diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 0157da71762d..01b1d4b973b6 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -81,8 +81,6 @@ class ConvertTorchToLinalg typeConverter, patterns, target); torch_to_linalg::populateTensorConstructorsPatternsAndLegality( typeConverter, patterns, target); - torch_to_linalg::populateTorchvisionPatternsAndLegality(typeConverter, - patterns, target); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/lib/Conversion/TorchToLinalg/TorchvisionOp.cpp b/lib/Conversion/TorchToLinalg/TorchvisionOp.cpp deleted file mode 100644 index ecea50728a2a..000000000000 --- a/lib/Conversion/TorchToLinalg/TorchvisionOp.cpp +++ /dev/null @@ -1,178 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" - -#include "PopulatePatterns.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -// #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" -#include "torch-mlir/Conversion/Utils/Utils.h" -// #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" -#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -// #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" -#include "torch-mlir/Dialect/Torch/Utils/Utils.h" - -using namespace mlir; -using namespace mlir::torch; -using namespace mlir::torch::Torch; - -namespace { -static Value calculateIoU(OpBuilder &b, Location loc, Value box1, Value box2) { - // box format: [x1, y1, x2, y2] with 0 <= x1 < x2 and 0 <= y1 < y2 - Value idx0 = b.create(loc, 0); - Value idx1 = b.create(loc, 1); - Value idx2 = b.create(loc, 2); - Value idx3 = b.create(loc, 3); - Value b1x1 = b.create(loc, box1, ValueRange{idx0}); - Value b1y1 = b.create(loc, box1, ValueRange{idx1}); - Value b1x2 = b.create(loc, box1, ValueRange{idx2}); - Value b1y2 = b.create(loc, box1, ValueRange{idx3}); - Value b2x1 = b.create(loc, box2, ValueRange{idx0}); - Value b2y1 = b.create(loc, box2, ValueRange{idx1}); - Value b2x2 = b.create(loc, box2, ValueRange{idx2}); - Value b2y2 = b.create(loc, box2, ValueRange{idx3}); - - // Calculate intersection width and height - Value intersectX1 = b.create(loc, b1x1, b2x1); - Value intersectY1 = b.create(loc, b1y1, b2y1); - Value intersectX2 = b.create(loc, b1x2, b2x2); - Value intersectY2 = b.create(loc, b1y2, b2y2); - // Width = max(0, intersectX2 - intersectX1) - Value zero = b.create(loc, b.getF32Type(), b.getF32FloatAttr(0.0)); - Value width = b.create(loc, intersectX2, intersectX1); - width = b.create(loc, width, zero); - // Height = max(0, intersectY2 - intersectY1) - Value height = b.create(loc, intersectY2, intersectY1); - height = b.create(loc, height, zero); - // Intersection area = width * height - Value intersectionArea = b.create(loc, width, height); - - // Calculate area of box1: (b1x2 - b1x1) * (b1y2 - b1y1) - Value width1 = b.create(loc, b1x2, b1x1); - Value height1 = b.create(loc, b1y2, b1y1); - Value area1 = b.create(loc, width1, height1); - // Calculate area of box2: (b2x2 - b2x1) * (b2y2 - b2y1) - Value width2 = b.create(loc, b2x2, b2x1); - Value height2 = b.create(loc, b2y2, b2y1); - Value area2 = b.create(loc, width2, height2); - // Union area = area1 + area2 - intersectionArea - Value unionArea = b.create(loc, area1, area2); - unionArea = b.create(loc, unionArea, intersectionArea); - - return b.create(loc, intersectionArea, unionArea); -} - -class ConvertTorchvisionNmsOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(TorchvisionNmsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) { - return failure(); - } - - Location loc = op->getLoc(); - Value boxes = adaptor.getOperands()[0]; - Value scores = adaptor.getOperands()[1]; - Value iouThreshold = adaptor.getOperands()[2]; - - auto boxesType = cast(boxes.getType()); - auto scoresType = cast(scores.getType()); - if (!boxesType || !scoresType) { - return failure(); - } - - // Calculate IoU for each pair of boxes - Type boxesElementType = boxesType.getElementType(); - Value cst0 = rewriter.create(loc, rewriter.getZeroAttr(boxesElementType)); - int64_t boxesSize = boxesType.getShape()[0]; - Value iouEmpty = - rewriter.create(loc, ArrayRef{boxesSize, boxesSize}, boxesElementType); - Value iouOutput = - rewriter.create(loc, cst0, iouEmpty).getResult(0); - - - // AffineExpr d0, d1; - // bindDims(getContext(), d0, d1); - // auto c0 = rewriter.getAffineConstantExpr(0); - // auto map = AffineMap::get(2, 0, {d0, d1}, rewriter.getContext()); - // auto map1 = AffineMap::get(2, 0, {d0, d1}, rewriter.getContext()); - // auto map2 = AffineMap::get(2, 0, {d0, d1}, rewriter.getContext()); - // SmallVector indexingMaps = {map, map1, map2}; - AffineMap inputMap = AffineMap::get(2, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)}, rewriter.getContext()); - AffineMap outputMap = AffineMap::get(2, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)}, rewriter.getContext()); - SmallVector indexingMaps = {inputMap, inputMap, outputMap}; - // SmallVector iteratorTypes( - // 2, utils::IteratorType::parallel); - SmallVector iteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::reduction}; - // Create the linalg.generic operation - Value result = - rewriter - .create( - loc, - /*resultTypes=*/iouOutput.getType(), - /*inputs=*/ValueRange{boxes, boxes}, - /*outputs=*/iouOutput, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value box1 = args[0], box2 = args[1], out = args[2]; - box1.getType().dump(); - box2.getType().dump(); - out.getType().dump(); - Value resultValue = calculateIoU(b, loc, box1, box2); - b.create(loc, resultValue); - }) - .getResult(0); - - - // Create a mask tensor where we mark suppressed boxes (0 for keep, 1 for suppress) - // Value maskEmpty = rewriter.create(loc, scoresType.getShape(), rewriter.getI1Type()); - // Value maskOutput = - // rewriter.create(loc, cst0, maskEmpty).getResult(0); - // maskOutput.dump(); - - // Value maskTensor = - // rewriter - // .create( - // loc, maskOutput.getType(), ValueRange{iouTensor, scores}, maskOutput, - // /*indexing_maps=*/AffineMap::inferFromExprList({ - // {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)} - // }), - // /*iterator_types=*/ArrayRef{"parallel", "parallel"}, - // [&](OpBuilder &b, Location loc, ValueRange args) { - // Value iouValue = args[0]; - // Value score = args[1]; - // // Check conditions to suppress based on IoU and score thresholds. - // Value isSuppressed = b.create(loc, arith::CmpFPredicate::OGT, iouValue, iouThreshold); - // b.create(loc, isSuppressed); - // }) - // .getResult(0); - - // // Convert the mask tensor to the desired output format (filtered boxes). - // rewriter.replaceOp(op, maskTensor); - - - llvm::outs() << "!!!\n"; - return success(); - } -}; -} // namespace - -void mlir::torch::torch_to_linalg::populateTorchvisionPatternsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { - MLIRContext *context = patterns.getContext(); - target.addIllegalOp(); - patterns.add(typeConverter, context); -} diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index efa93d31e8fe..27e87c1a8d21 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -10501,6 +10501,102 @@ class DecomposeAtenFloatPowerTensorTensorOp }; } // namespace +static Value calculateIoU(PatternRewriter &rewriter, Location loc, Value box1, + Value box2) { + // box format: 1x4xf32 with [x1, y1, x2, y2], 0 <= x1 < x2 and 0 <= y1 < y2 + Value cst0 = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value cst1 = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value cst2 = + rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + Value cst3 = + rewriter.create(loc, rewriter.getI64IntegerAttr(3)); + auto scalarTensorType = rewriter.getType( + ArrayRef{1}, rewriter.getIntegerType(64, /*signed*/ true)); + Value cst0Tensor = rewriter.create( + loc, scalarTensorType, cst0); + Value cst1Tensor = rewriter.create( + loc, scalarTensorType, cst1); + Value cst2Tensor = rewriter.create( + loc, scalarTensorType, cst2); + Value cst3Tensor = rewriter.create( + loc, scalarTensorType, cst3); + auto boxesTensorType = dyn_cast(box1.getType()); + auto extractTy = rewriter.getType( + ArrayRef{1}, boxesTensorType.getDtype()); + Value b1x1 = rewriter.create( + loc, extractTy, box1, /*dim=*/cst1, /*index=*/cst0Tensor); + Value b1y1 = rewriter.create( + loc, extractTy, box1, /*dim=*/cst1, /*index=*/cst1Tensor); + Value b1x2 = rewriter.create( + loc, extractTy, box1, /*dim=*/cst1, /*index=*/cst2Tensor); + Value b1y2 = rewriter.create( + loc, extractTy, box1, /*dim=*/cst1, /*index=*/cst3Tensor); + Value b2x1 = rewriter.create( + loc, extractTy, box2, /*dim=*/cst1, /*index=*/cst0Tensor); + Value b2y1 = rewriter.create( + loc, extractTy, box2, /*dim=*/cst1, /*index=*/cst1Tensor); + Value b2x2 = rewriter.create( + loc, extractTy, box2, /*dim=*/cst1, /*index=*/cst2Tensor); + Value b2y2 = rewriter.create( + loc, extractTy, box2, /*dim=*/cst1, /*index=*/cst3Tensor); + + // Calculate intersection width and height + Value intersectX1 = + rewriter.create(loc, extractTy, b1x1, b2x1); + Value intersectY1 = + rewriter.create(loc, extractTy, b1y1, b2y1); + Value intersectX2 = + rewriter.create(loc, extractTy, b1x2, b2x2); + Value intersectY2 = + rewriter.create(loc, extractTy, b1y2, b2y2); + // Width = max(0, intersectX2 - intersectX1) + Value float0 = rewriter.create( + loc, rewriter.getF32FloatAttr(0.0)); + auto scalarFloatType = rewriter.getType( + ArrayRef{1}, rewriter.getF32Type()); + Value float0Tensor = rewriter.create( + loc, scalarFloatType, float0); + Value width = rewriter.create( + loc, extractTy, intersectX2, intersectX1, cst1); + width = rewriter.create(loc, extractTy, width, + float0Tensor); + // Height = max(0, intersectY2 - intersectY1) + Value height = rewriter.create( + loc, extractTy, intersectY2, intersectY1, cst1); + height = rewriter.create(loc, extractTy, height, + float0Tensor); + // Intersection area = width * height + Value intersectionArea = + rewriter.create(loc, extractTy, width, height); + + // Calculate area of box1: (b1x2 - b1x1) * (b1y2 - b1y1) + Value width1 = + rewriter.create(loc, extractTy, b1x2, b1x1, cst1); + Value height1 = + rewriter.create(loc, extractTy, b1y2, b1y1, cst1); + Value area1 = + rewriter.create(loc, extractTy, width1, height1); + // Calculate area of box2: (b2x2 - b2x1) * (b2y2 - b2y1) + Value width2 = + rewriter.create(loc, extractTy, b2x2, b2x1, cst1); + Value height2 = + rewriter.create(loc, extractTy, b2y2, b2y1, cst1); + Value area2 = + rewriter.create(loc, extractTy, width2, height2); + // Union area = area1 + area2 - intersectionArea + Value unionArea = rewriter.create(loc, extractTy, + area1, area2, cst1); + unionArea = rewriter.create(loc, extractTy, unionArea, + intersectionArea, cst1); + + Value iouTensor = rewriter.create( + loc, extractTy, intersectionArea, unionArea); + return rewriter.create( + loc, rewriter.getType(), iouTensor); +} + namespace { class DecomposeTorchvisionNmsOp : public OpRewritePattern { public: @@ -10509,8 +10605,194 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); - Value input = op.getSelf(); + Value boxes = op.getDets(); + Value scores = op.getScores(); + Value iouThreshold = op.getIouThreshold(); + + Value cst0 = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value cst1 = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value cstNone = rewriter.create(loc); + Value cstTrue = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value cstFalse = rewriter.create( + loc, rewriter.getBoolAttr(false)); + + // Sort scores in descending order + // Use the sorted indices to iterate boxes + auto scoresType = dyn_cast(scores.getType()); + auto sortIndicesType = scoresType.getWithSizesAndDtype( + scoresType.getOptionalSizes(), + IntegerType::get(context, 64, IntegerType::Signed)); + auto sortResult = rewriter.create( + loc, TypeRange({scores.getType(), sortIndicesType}), scores, + /*dim=*/cst0, /*descending=*/cstTrue); + + // Get number of boxes for the loop count + auto boxesTensorType = dyn_cast(boxes.getType()); + int64_t boxesSize = boxesTensorType.getSizes()[0]; + Value len = rewriter.create( + loc, rewriter.getI64IntegerAttr(boxesSize)); + + // Create a mask to mark if we keep the boxes + Value maskShapeList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + SmallVector{len}); + auto maskTy = + ValueTensorType::get(context, ArrayRef{boxesSize}, + rewriter.getIntegerType(64, /*signed*/ true)); + Value mask = rewriter.create( + loc, maskTy, maskShapeList, cstNone, cstNone, cstNone, cstNone); + + llvm::SmallVector sliceSizes = {1, 4}; + auto sliceTy = rewriter.getType( + sliceSizes, boxesTensorType.getDtype()); + + // 1. Loop through the boxes based on sorted indices + // 2. Check the mask if it's marked as suppressed + // 3. Loop through the rest boxes in sorted indices + // 4. Extract the coordinates of two boxes and calculate IoU + // 5. Mark the second box as suppressed if IOU is larger than threshold + auto loop1 = + rewriter.create(loc, maskTy, len, cstTrue, mask); + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *rowLoopBody = rewriter.createBlock( + &loop1.getRegion(), loop1.getRegion().begin(), + TypeRange({rewriter.getType(), mask.getType()}), + {loc, loc}); + Value i = rowLoopBody->getArgument(0); + + // Extract the mask to check if the base box is suppressed + auto extractTy = rewriter.getType( + llvm::SmallVector{1}, rewriter.getIntegerType(64, true)); + Value extract = rewriter.create( + loc, extractTy, mask, /*dim=*/cst0, /*index=*/i); + Value scalar = rewriter.create( + loc, rewriter.getType(), extract); + Value iskept = rewriter.create( + loc, rewriter.getType(), scalar); + auto ifFilterOthers = + rewriter.create(loc, TypeRange({maskTy}), iskept); + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *thenBlock = + rewriter.createBlock(&ifFilterOthers.getThenRegion(), + ifFilterOthers.getThenRegion().begin()); + + // Extract the index from sorted indices to get the coordinates + Value extractIdx1 = rewriter.create( + loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0, + /*index=*/i); + Value idx1 = rewriter.create( + loc, rewriter.getType(), extractIdx1); + Value end1 = rewriter.create(loc, idx1, cst1); + Value slice1 = rewriter.create( + loc, sliceTy, boxes, + /*dim=*/cst0, /*start=*/idx1, /*end=*/end1, /*step=*/cst1); + + // Loop through the rest of boxes + auto loop2 = + rewriter.create(loc, maskTy, len, cstTrue, mask); + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *colLoopBody = rewriter.createBlock( + &loop2.getRegion(), loop2.getRegion().begin(), + TypeRange({rewriter.getType(), mask.getType()}), + {loc, loc}); + + // Check if current index is out of range + Value j = colLoopBody->getArgument(0); + j = rewriter.create(loc, j, i); + j = rewriter.create(loc, j, cst1); + Value isInRange = rewriter.create(loc, j, len); + auto ifCalculateIou = rewriter.create( + loc, TypeRange({maskTy}), isInRange); + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *thenBlock = + rewriter.createBlock(&ifCalculateIou.getThenRegion(), + ifCalculateIou.getThenRegion().begin()); + + // Extract the coordinates for the second box + Value extractIdx2 = rewriter.create( + loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0, + /*index=*/j); + Value idx2 = rewriter.create( + loc, rewriter.getType(), extractIdx2); + Value end2 = rewriter.create(loc, idx2, cst1); + Value slice2 = rewriter.create( + loc, sliceTy, boxes, + /*dim=*/cst0, /*start=*/idx2, /*end=*/end2, /*step=*/cst1); + + // Calculate IoU and decide if suppress it + Value iou = calculateIoU(rewriter, loc, slice1, slice2); + Value isSuppressed = + rewriter.create(loc, iou, iouThreshold); + auto ifUnmask = rewriter.create( + loc, TypeRange({maskTy}), isSuppressed); + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *thenBlock = rewriter.createBlock( + &ifUnmask.getThenRegion(), ifUnmask.getThenRegion().begin()); + + Value zerosShapeList = + rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + SmallVector{cst1}); + auto zeroTy = + ValueTensorType::get(context, ArrayRef{1}, + rewriter.getIntegerType(64, true)); + Value falseMask = rewriter.create( + loc, zeroTy, zerosShapeList, cstNone, cstNone, cstNone, + cstNone); + Value end3 = rewriter.create(loc, j, cst1); + Value thenMask = rewriter.create( + loc, maskTy, mask, falseMask, cst0, + /*start=*/j, /*end=*/end3, /*step=*/cst1); + rewriter.create(loc, thenMask); + } + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *elseBlock = rewriter.createBlock( + &ifUnmask.getElseRegion(), ifUnmask.getElseRegion().begin()); + Value elseMask = mask; + rewriter.create(loc, elseMask); + } + + rewriter.create(loc, ifUnmask.getResult(0)); + } + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *elseBlock = + rewriter.createBlock(&ifCalculateIou.getElseRegion(), + ifCalculateIou.getElseRegion().begin()); + Value elseMask = mask; + rewriter.create(loc, elseMask); + } + + rewriter.create( + loc, cstTrue, ifCalculateIou.getResult(0)); + } + rewriter.create(loc, loop2.getResult(0)); + } + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *elseBlock = + rewriter.createBlock(&ifFilterOthers.getElseRegion(), + ifFilterOthers.getElseRegion().begin()); + Value elseMask = mask; + rewriter.create(loc, elseMask); + } + + rewriter.create(loc, cstTrue, + ifFilterOthers.getResult(0)); + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), sortResult.getResults()[1], loop1.getResult(0)); return success(); } }; @@ -10797,6 +11079,9 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal< DecomposeAtenFMaxMinOp>(patterns); + // Torchvision ops + addPatternIfTargetOpIsIllegal(patterns); + GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit;