From 6cba93b16ef4f1bf7ec30481fdd0422a4c33b15d Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Fri, 17 May 2024 14:18:57 -0500 Subject: [PATCH] [ONNX][TorchToLinalg] Add support for dynamic dims in Interpolate lowering (#3351) Addresses [Shark-Turbine #196](https://github.com/nod-ai/SHARK-TestSuite/issues/196) Related tracker [Shark-Turbine #566](https://github.com/nod-ai/SHARK-Turbine/issues/566) Related onnx.Resize issues [Shark-Turbine #616](https://github.com/nod-ai/SHARK-Turbine/issues/616) --- .../TorchToLinalg/Uncategorized.cpp | 26 +++++++------------ projects/pt1/e2e_testing/xfail_sets.py | 3 --- test/Conversion/TorchToLinalg/resize.mlir | 12 +++------ 3 files changed, 13 insertions(+), 28 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e369df0d066e..76a4c8656b54 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2912,11 +2912,13 @@ class ConvertInterpolateOp auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); - if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) { - return rewriter.notifyMatchFailure(op, "error: Dynamic dim on resize op"); - } - SmallVector outputSizeIntValues; + Value inputSizeH = getDimOp(rewriter, loc, input, 2); + inputSizeH = rewriter.create( + loc, rewriter.getIntegerType(64), inputSizeH); + Value inputSizeW = getDimOp(rewriter, loc, input, 3); + inputSizeW = rewriter.create( + loc, rewriter.getIntegerType(64), inputSizeW); if (!op.getScaleFactor().getType().isa()) { SmallVector ScaleFactorTorchFloat; @@ -2927,8 +2929,6 @@ class ConvertInterpolateOp SmallVector ScaleFactorFloatValues; ScaleFactorFloatValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat); - Value inputSizeH = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputType.getShape()[2])); Value inputHFP = rewriter.create( loc, rewriter.getF32Type(), inputSizeH); Value scale = rewriter.create(loc, inputHFP.getType(), @@ -2938,8 +2938,6 @@ class ConvertInterpolateOp outputH = rewriter.create(loc, rewriter.getI64Type(), outputH); - Value inputSizeW = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputType.getShape()[3])); Value inputWFP = rewriter.create( loc, rewriter.getF32Type(), inputSizeW); scale = rewriter.create(loc, inputWFP.getType(), @@ -2960,11 +2958,9 @@ class ConvertInterpolateOp outputSizeIntValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), outputSizeTorchInt); } - int hDimOffset = 2; - SmallVector dims = getTensorSizes(rewriter, loc, input); - dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]); - dims[hDimOffset + 1] = - castIntToIndex(rewriter, loc, outputSizeIntValues[1]); + SmallVector dims = getTensorSizesUntilDim(rewriter, loc, input, 1); + dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[0])); + dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[1])); Value outTensor = rewriter.create( loc, getAsOpFoldResult(dims), inputType.getElementType()); @@ -2983,10 +2979,6 @@ class ConvertInterpolateOp [&](OpBuilder &b, Location loc, ValueRange args) { Value outputSizeH = outputSizeIntValues[0]; Value outputSizeW = outputSizeIntValues[1]; - Value inputSizeH = b.create( - loc, b.getI64IntegerAttr(inputType.getShape()[2])); - Value inputSizeW = b.create( - loc, b.getI64IntegerAttr(inputType.getShape()[3])); Value retVal; if (mode == "nearest") { retVal = diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index f7904fc7f85c..72c495b1ba0d 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2607,9 +2607,6 @@ "BernoulliTensorModule_basic", # Failure - onnx_lowering: onnx.ReduceProd "ReduceProdDimIntFloatModule_basic", - # Failure - onnx_lowering: onnx.Resize - "UpSampleNearest2dDynamicSize_basic", - "UpSampleNearest2dStaticSize_basic", # Failure - onnx_lowering: onnx.ScatterElements "ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMinModuleIncludeSelf", diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 480454b3f1fc..9850a5fdabd6 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -4,15 +4,13 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] ,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[generic:.*]] = linalg.generic - // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64 - // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64 // CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32 // CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32 // CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[x13:.*]] = linalg.index 2 : index // CHECK: %[[x14:.*]] = linalg.index 3 : index - // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32 + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 // CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 // CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32 // CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64 @@ -23,7 +21,7 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: // CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32 // CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32 // CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32 - // CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32 + // CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 // CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32 // CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64 @@ -96,12 +94,10 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[GENERIC:.*]] = linalg.generic - // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64 - // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64 // CHECK: %[[x13:.*]] = linalg.index 2 : index // CHECK: %[[x14:.*]] = linalg.index 3 : index - // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32 - // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32 + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32