From 09326efc8b39135b3b143c4210fc781e4d4c6731 Mon Sep 17 00:00:00 2001 From: jinchen62 Date: Wed, 25 Sep 2024 07:56:10 -0700 Subject: [PATCH] Generalize max_unpool lowering --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 ----------- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 5 --- lib/Conversion/TorchToLinalg/Pooling.cpp | 17 ++++---- projects/pt1/e2e_testing/xfail_sets.py | 2 + .../build_tools/torch_ods_gen.py | 1 - .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 16 ++++--- test/Conversion/TorchToLinalg/pooling.mlir | 42 +++++++++++++++++++ 7 files changed, 64 insertions(+), 44 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c9329ccb895d..6a0912a0268f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7135,31 +7135,6 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ }]; } -def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$indices, - AnyTorchListOfTorchIntType:$output_size - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 168040d9b289..3cb219b57938 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -3430,11 +3430,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( SmallVector resultShape(resultType.getSizes()); Value resultShapeList = createConstantIntList(binder, rewriter, resultShape); - if (rank == 4) { - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, indices, resultShapeList); - return success(); - } SmallVector padding, strides; if (binder.s64IntegerArrayAttr(padding, "pads", {})) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 90b5b2af77a8..b918615f8634 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -611,21 +611,22 @@ class ConvertAtenMaxUnpool3dOp final Value self = adaptor.getSelf(); auto selfType = cast(self.getType()); - ArrayRef inputSize = selfType.getShape().take_back(3); + size_t spatial = selfType.getRank() - 2; + ArrayRef inputSize = selfType.getShape().take_back(spatial); if (ShapedType::isDynamicShape(inputSize)) return rewriter.notifyMatchFailure(op, "input type must be of static shape"); Value indices = adaptor.getIndices(); auto indicesType = cast(indices.getType()); - if (inputSize != indicesType.getShape().take_back(3)) + if (inputSize != indicesType.getShape().take_back(spatial)) return rewriter.notifyMatchFailure(op, "input/indices shape mismatch"); auto resType = typeConverter->convertType(op.getType()); if (!resType) return rewriter.notifyMatchFailure(op, "invalid result type"); - ArrayRef inferredOutSize = resType.getShape().take_back(3); + ArrayRef inferredOutSize = resType.getShape().take_back(spatial); if (ShapedType::isDynamicShape(inferredOutSize)) return rewriter.notifyMatchFailure(op, "output type must be of static shape"); @@ -636,7 +637,7 @@ class ConvertAtenMaxUnpool3dOp final return rewriter.notifyMatchFailure(op, "only support constant int output"); - if (inferredOutSize != ArrayRef(output)) + if (inferredOutSize != ArrayRef(output).take_back(spatial)) return rewriter.notifyMatchFailure(op, "Invalid output size"); } SmallVector stride; @@ -652,12 +653,12 @@ class ConvertAtenMaxUnpool3dOp final // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool" // (padding.size() == 6). - if (stride.size() != 3 || padding.size() != 3) + if (stride.size() != spatial || padding.size() != spatial) return rewriter.notifyMatchFailure( op, "stride and padding must be of size 3"); int64_t outRank = resType.getRank(); - int64_t NC = outRank - 3; + int64_t NC = outRank - spatial; for (auto &&[inDim, outDim, str, pad] : llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) { @@ -694,7 +695,7 @@ class ConvertAtenMaxUnpool3dOp final // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2) // pad self and indices tensors to avoid out of bounds access. SmallVector expectedInputShape = - llvm::to_vector(resType.getShape().drop_back(3)); + llvm::to_vector(resType.getShape().drop_back(spatial)); for (auto &&[str, pad, resSize] : llvm::zip_equal(stride, padding, inferredOutSize)) expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2); @@ -707,7 +708,7 @@ class ConvertAtenMaxUnpool3dOp final SmallVector low(outRank, 0); SmallVector high(NC, 0); for (auto &&[inpSize, outSize] : llvm::zip_equal( - inputSize, ArrayRef(expectedInputShape).take_back(3))) { + inputSize, ArrayRef(expectedInputShape).take_back(spatial))) { high.emplace_back(outSize - inpSize); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 3b3e4611ea6b..d8fe2381f7a1 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -455,6 +455,8 @@ "IntFloatModule_basic", "IntImplicitModule_basic", "LenStrModule_basic", + # "MaxUnpool3dModulePad0_basic", + # "MaxUnpool3dModule_basic", "MulFloatModule_basic", "NativeGroupNormBackwardModule_basic", "NeFloatIntModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index f3227f29b5ce..955b99ce632d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -621,7 +621,6 @@ def emit_with_mutating_variants(key, **kwargs): "aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" ) emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") - emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)") emit( "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", has_canonicalizer=True, diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 21be2a65f4a6..68119eb57841 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1667,14 +1667,20 @@ func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torc // ----- -// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape -func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-LABEL: func.func @test_maxunpool_2d_export_without_output_shape +func.func @test_maxunpool_2d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 // CHECK: %[[INT4:.*]] = torch.constant.int 4 // CHECK: %[[INT4_0:.*]] = torch.constant.int 4 // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list -> !torch.vtensor<[1,1,4,4],f32> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list, !torch.list, !torch.list -> !torch.vtensor<[1,1,4,4],f32> // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32> %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> return %0 : !torch.vtensor<[1,1,4,4],f32> @@ -1682,8 +1688,8 @@ func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1 // ----- -// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape -func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-LABEL: func.func @test_maxunpool_3d_export_without_output_shape +func.func @test_maxunpool_3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 // CHECK: %[[INT4:.*]] = torch.constant.int 4 diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index 558c50c4f08f..bc5483ad60c6 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -95,3 +95,45 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch. // CHECK: } -> tensor return %4 : !torch.vtensor<[?,?,?,?,?],f32> } + +// ----- + +// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)> +// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @forward_max_unpool2d +func.func @forward_max_unpool2d(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %int1 = torch.constant.int 1 + %int1_0 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int4_1 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int1, %int1_0, %int4, %int4_1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0 = torch.constant.int 0 + %int0_2 = torch.constant.int 0 + %1 = torch.prim.ListConstruct %int0, %int0_2 : (!torch.int, !torch.int) -> !torch.list + %int2 = torch.constant.int 2 + %int2_3 = torch.constant.int 2 + %2 = torch.prim.ListConstruct %int2, %int2_3 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.aten.max_unpool %arg0, %arg1, %0, %2, %1 : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list, !torch.list, !torch.list -> !torch.vtensor<[1,1,4,4],f32> + + // CHECK: %[[INDICES:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,1,2,2],si64> -> tensor<1x1x2x2xi64> + // CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2,2],f32> -> tensor<1x1x2x2xf32> + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[DIM0:.*]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<1x1x2x2xf32> + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[DIM1:.*]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<1x1x2x2xf32> + // CHECK: %[[SHAPE:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor + // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[INPUT]], %[[INDICES]] : tensor<1x1x2x2xf32>, tensor<1x1x2x2xi64>) outs(%[[SHAPE]] : tensor) { + // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[CURRENT_INDEX:.*]]: i64, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK-NEXT: %[[INDEX_CAST:.*]] = arith.index_cast %[[CURRENT_INDEX:.*]] : i64 to index + // CHECK-NEXT: %[[INDEX2:.*]] = linalg.index 2 : index + // CHECK-NEXT: %[[INDEX3:.*]] = linalg.index 3 : index + // CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : index + // CHECK-NEXT: %[[MULI:.*]] = arith.muli %[[INDEX2:.*]], %[[C4:.*]] : index + // CHECK-NEXT: %[[ADDI:.*]] = arith.addi %[[MULI:.*]], %[[INDEX3:.*]] : index + // CHECK-NEXT: %[[CMPI:.*]] = arith.cmpi eq, %[[INDEX_CAST:.*]], %[[ADDI:.*]] : index + // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMPI:.*]], %[[CURRENT_VALUE:.*]], %[[CST:.*]] : f32 + // CHECK-NEXT: linalg.yield %[[SELECT:.*]] : f32 + // CHECK: } -> tensor + return %3 : !torch.vtensor<[1,1,4,4],f32> +}