diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index f7fac538068a..889a5fe88704 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1463,25 +1463,21 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( }); patterns.onOp( "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Location loc = binder.getLoc(); Torch::ValueTensorType resultType; - Value operand; - Value axisTensor; + Value operand, axisTensor; + int64_t exclusive, reverse; if (binder.tensorOperands(operand, axisTensor) || + binder.s64IntegerAttr(exclusive, "exclusive", 0) || + binder.s64IntegerAttr(reverse, "reverse", 0) || binder.tensorResultType(resultType)) return failure(); - int64_t exclusive; - int64_t reverse; - // if bind succeeds and either is set, fail because not implemented - if (!binder.s64IntegerAttr(exclusive, "exclusive", 0)) - if (exclusive != 0) - return rewriter.notifyMatchFailure( - binder.op, "unsupported onnx.CumSum conversion: exclusive"); - if (!binder.s64IntegerAttr(reverse, "reverse", 0)) - if (reverse != 0) - return rewriter.notifyMatchFailure( - binder.op, "unsupported onnx.CumSum conversion: reverse"); + Torch::BaseTensorType resultTensorType = + cast(resultType); + if (!resultTensorType.hasDtype()) { + return rewriter.notifyMatchFailure( + binder.op, "expected result type to have a dtype"); + } // deal with neg axis: if (axis < 0) axis += rank int64_t rank = @@ -1489,30 +1485,45 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value rankVal = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank)); - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value axisScalar = rewriter.create( binder.getLoc(), rewriter.getType(), axisTensor); Value isNegative = rewriter.create( - binder.getLoc(), axisScalar, zero); + binder.getLoc(), axisScalar, cstZero); isNegative = rewriter.create(binder.getLoc(), isNegative); Value finalOffset = rewriter.create( binder.getLoc(), isNegative, rankVal); - Value dim = rewriter.create( + Value axis = rewriter.create( binder.getLoc(), axisScalar, finalOffset); + Value none = rewriter.create(binder.getLoc()); - Torch::BaseTensorType resultTensorType = - cast(resultType); - if (!resultTensorType.hasDtype()) { - return rewriter.notifyMatchFailure( - binder.op, "expected result type to have a dtype"); + Value res; + if (reverse) { + Value dims = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{axis}); + Value flip = rewriter.create( + binder.getLoc(), resultType, operand, dims); + Value cumsum = rewriter.create( + binder.getLoc(), resultType, flip, axis, none); + res = rewriter.create(binder.getLoc(), resultType, + cumsum, dims); + } else { + res = rewriter.create( + binder.getLoc(), resultType, operand, axis, none); } - // resultTensorType.print(llvm::outs()); - Value none = rewriter.create(loc); - rewriter.replaceOpWithNewOp(binder.op, resultType, - operand, dim, none); + + if (exclusive) + res = rewriter.create( + binder.getLoc(), resultType, res, operand, cstOne); + rewriter.replaceOp(binder.op, res); return success(); }); patterns.onOp( diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 2f0a30b127d5..2615f7b7a36a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -123,12 +123,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (gridShape[3] != 2) return rewriter.notifyMatchFailure(binder.op, "gridShape[3] expected to be 2"); - std::string mode; - if (binder.customOpNameStringAttr(mode, "mode", "linear")) + std::string iModeString; + int64_t iModeInt; + if (binder.customOpNameStringAttr(iModeString, "mode", "linear")) return rewriter.notifyMatchFailure(binder.op, "mode bind failure"); - if (mode != "linear" && mode != "bilinear") + + if (iModeString == "linear" || iModeString == "bilinear") { + iModeInt = 0; + } else if (iModeString == "nearest") { + iModeInt = 1; + } else { return rewriter.notifyMatchFailure( - binder.op, "currently only mode : linear supported"); + binder.op, "currently only mode : linear and nearest supported"); + } + std::string padding; if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros")) return rewriter.notifyMatchFailure(binder.op, @@ -143,7 +151,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value interpolationMode = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt)); + Value paddingMode = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 8d2bdac4b6ce..40c1ac01c858 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2524,14 +2524,38 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value result = b.create(loc, input, index); return result; }; - auto lambdaInter = [&](OpBuilder &b, Location loc, Value x, Value y, - Value d) -> Value { + + auto lambdaLinear = [&](OpBuilder &b, Location loc, Value x, Value y, + Value d) -> Value { Value dm = b.create(loc, oneFloat, d); Value ra = b.create(loc, x, dm); Value rb = b.create(loc, y, d); Value res = b.create(loc, ra, rb); return res; }; + + auto lambdaNearest = [&](OpBuilder &b, Location loc, Value x, Value y, + Value d) -> Value { + Value halfConst = rewriter.create( + loc, rewriter.getFloatAttr(floatType, 0.5)); + Value checkClosest = + b.create(loc, arith::CmpFPredicate::OLT, d, halfConst); + Value res = b.create(loc, checkClosest, x, y); + return res; + }; + + auto lambdaInterpolate = [&](OpBuilder &b, Location loc, Value iMode, + Value x, Value y, Value d) -> Value { + Value linear = lambdaLinear(b, loc, x, y, d); + Value nearest = lambdaNearest(b, loc, x, y, d); + Value zeroInt = + b.create(loc, b.getIntegerAttr(int64type, 0)); + Value checkMode = b.create(loc, arith::CmpIPredicate::eq, + iMode, zeroInt); + Value res = b.create(loc, checkMode, linear, nearest); + return res; + }; + auto resultType = getTypeConverter() ->convertType(op.getResult().getType()) .cast(); @@ -2545,6 +2569,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { if (resultType.isDynamicDim(3)) resultSize.push_back(rewriter.create(loc, grid, 2)); Value alignCorners = adaptor.getAlignCorners(); + Value interMode = adaptor.getInterpolationMode(); Value resultFinal = rewriter.create(loc, resultType, resultSize); auto sGrid = rewriter.create( @@ -2633,10 +2658,12 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value lw1a = b.create(loc, floatType, lower1); Value d1 = b.create(loc, result0, lw0a); Value d0 = b.create(loc, result1, lw1a); - Value resultScaled0 = lambdaInter(b, loc, result00b, result01b, d0); - Value resultScaled1 = lambdaInter(b, loc, result10b, result11b, d0); - Value resultScaled = - lambdaInter(b, loc, resultScaled0, resultScaled1, d1); + Value resultScaled0 = + lambdaInterpolate(b, loc, interMode, result00b, result01b, d0); + Value resultScaled1 = + lambdaInterpolate(b, loc, interMode, result10b, result11b, d0); + Value resultScaled = lambdaInterpolate( + b, loc, interMode, resultScaled0, resultScaled1, d1); b.create(loc, resultScaled); }); rewriter.replaceOp(op, sGrid.getResults()); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py index 7b1699c64c77..73ed06c46877 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py @@ -115,3 +115,41 @@ def GridSamplerBasic3_basic(module, tu: TestUtils): [[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]] ).type(torch.FloatTensor) module.forward(inp, grd) + + +class GridSamplerBasic4(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [None, ([1, 1, 4, 4], torch.float32, True), ([1, 1, 3, 2], torch.float32, True)] + ) + def forward(self, x, g): + interpolation_mode = (1,) + padding_mode = (0,) + align_corners = (False,) + tRes = torch.ops.aten.grid_sampler( + x, g, interpolation_mode[0], padding_mode[0], align_corners[0] + ) + return tRes + + +@register_test_case(module_factory=lambda: GridSamplerBasic4()) +def GridSamplerBasic4_basic(module, tu: TestUtils): + inp = torch.tensor( + [ + [ + [ + [0.4963, 0.7682, 0.0885, 0.1320], + [0.3074, 0.6341, 0.4901, 0.8964], + [0.4556, 0.6323, 0.3489, 0.4017], + [0.0223, 0.1689, 0.2939, 0.5185], + ] + ] + ] + ).type(torch.FloatTensor) + grd = torch.tensor( + [[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]] + ).type(torch.FloatTensor) + module.forward(inp, grd) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index e8266c04ffad..a87ec4f8f43f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1324,6 +1324,85 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a // ----- +// CHECK-LABEL: @test_cumsum_1d_exclusive +func.func @test_cumsum_1d_exclusive(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RANK:.*]] = torch.constant.int 1 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int + // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %arg0, %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64> + // CHECK: torch.aten.sub.Tensor %[[CUMSUM]], %arg0, %[[C1]] : !torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>, !torch.int -> !torch.vtensor<[5],f64> + %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.exclusive = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> + return %0 : !torch.vtensor<[5],f64> +} + +// ----- + +// CHECK-LABEL: @test_cumsum_1d_reverse +func.func @test_cumsum_1d_reverse(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RANK:.*]] = torch.constant.int 1 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int + // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list -> !torch.vtensor<[5],f64> + // CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %[[FLIP]], %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64> + // CHECK: torch.aten.flip %[[CUMSUM]], %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list -> !torch.vtensor<[5],f64> + %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.reverse = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> + return %0 : !torch.vtensor<[5],f64> +} + +// ----- + +// CHECK-LABEL: @test_cumsum_1d_reverse_exclusive +func.func @test_cumsum_1d_reverse_exclusive(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RANK:.*]] = torch.constant.int 1 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int + // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list -> !torch.vtensor<[5],f64> + // CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %[[FLIP]], %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64> + // CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[CUMSUM]], %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list -> !torch.vtensor<[5],f64> + // CHECK: torch.aten.sub.Tensor %[[FLIP_0]], %arg0, %[[C1]] : !torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>, !torch.int -> !torch.vtensor<[5],f64> + %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.exclusive = 1 : si64, torch.onnx.reverse = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> + return %0 : !torch.vtensor<[5],f64> +} + +// ----- + +// CHECK-LABEL: @test_cumsum_2d +func.func @test_cumsum_2d(%arg0: !torch.vtensor<[2,3],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RANK:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int + // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // torch.aten.cumsum %arg0, %[[ADD]], %[[NONE]] : !torch.vtensor<[2,3],f64>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f64> + %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> + return %0 : !torch.vtensor<[2,3],f64> +} + +// ----- + // CHECK-LABEL: func.func @test_exp func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64} { // CHECK: torch.aten.exp %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 991a7075c863..c0f93864f9ee 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -545,7 +545,7 @@ func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso // CHECK: %[[B0:.*]] = torch.constant.bool false // CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> func.func @test_grid_sampler01(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.mode = "linear", torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } @@ -563,6 +563,18 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t // ----- +// CHECK-LABEL: @test_grid_sampler03 +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[B0:.*]] = torch.constant.bool true +// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT1]], %[[INT0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> +func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.mode = "nearest", torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_less_or_equal func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir index b4f9dfbb30f2..3fc02201748e 100644 --- a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir +++ b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir @@ -37,12 +37,3 @@ func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>) %211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32> return %211 : !torch.vtensor<[1,64,1],f32> } - -// ----- -// Fixed. -func.func @cumsum_operation(%arg0: !torch.vtensor<[2,3],f64>, - %arg1: !torch.vtensor<[],si32>) - -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - %212 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> - return %212 : !torch.vtensor<[2,3],f64> -} diff --git a/test/Conversion/TorchToLinalg/gridsampler.mlir b/test/Conversion/TorchToLinalg/gridsampler.mlir index 456cfc934471..7c099c5ce4f6 100644 --- a/test/Conversion/TorchToLinalg/gridsampler.mlir +++ b/test/Conversion/TorchToLinalg/gridsampler.mlir @@ -2,27 +2,32 @@ // CHECK: #map // CHECK-LABEL: func @grid_sampler -// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32> -// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32> -// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[TC0]], %[[C2_3]] : tensor<4x10x10x4xf32> -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[DIM_4:.*]] = tensor.dim %[[TC0]], %[[C3]] : tensor<4x10x10x4xf32> -// CHECK-DAG: %[[X2:.*]] = arith.subi %[[DIM:.*]], %[[C1]] : index -// CHECK-DAG: %[[X3:.*]] = arith.subi %[[DIM_4]], %[[C1:.*]] : index -// CHECK-DAG: %[[X4:.*]] = arith.index_cast %[[X2]] : index to i64 -// CHECK-DAG: %[[X5:.*]] = arith.index_cast %[[X3]] : index to i64 -// CHECK-DAG: %[[X6:.*]] = arith.sitofp %[[X4]] : i64 to f32 -// CHECK-DAG: %[[X7:.*]] = arith.sitofp %[[X5]] : i64 to f32 -// CHECK-DAG: %[[X8:.*]] = arith.divf %[[X6]], %[[CST2]] : f32 -// CHECK-DAG: %[[X9:.*]] = arith.divf %[[X7]], %[[CST2]] : f32 +// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32> +// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32> +// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[TC0]], %[[C2_3]] : tensor<4x10x10x4xf32> +// CHECK-DAG: %[[X73:.*]] = arith.cmpi eq, %[[X3:.*]], %[[C27:.*]] : i64 +// CHECK-DAG: %[[X74:.*]] = arith.select %[[X73:.*]], %[[X70:.*]], %[[X72:.*]] : f32 +// CHECK-DAG: %[[X75:.*]] = arith.subf %[[Xcst_1:.*]], %[[X57:.*]] : f32 +// CHECK-DAG: %[[X76:.*]] = arith.mulf %[[X66:.*]], %[[X75:.*]] : f32 +// CHECK-DAG: %[[X77:.*]] = arith.mulf %[[X74:.*]], %[[X57:.*]] : f32 +// CHECK-DAG: %[[X78:.*]] = arith.addf %[[X76:.*]], %[[X77:.*]] : f32 +// CHECK-DAG: %[[C28:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: %[[X79:.*]] = arith.cmpf olt, %[[X57:.*]], %[[X28:.*]] : f32 +// CHECK-DAG: %[[X80:.*]] = arith.select %[[X79:.*]], %[[X66:.*]], %[[X74:.*]] : f32 +// CHECK-DAG: %[[C29:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[X81:.*]] = arith.cmpi eq, %[[X3:.*]], %[[C29:.*]] : i64 +// CHECK-DAG: %[[X82:.*]] = arith.select %[[X81:.*]], %[[X78:.*]], %[[X80:.*]] : f32 +// CHECK-DAG: linalg.yield %[[X82:.*]] : f32 +// CHECK-DAG: %[[X14:.*]] = torch_c.from_builtin_tensor %[[X13:.*]] : tensor -> !torch.vtensor<[?,?,?,?],f32> + func.func @grid_sampler(%arg0: !torch.vtensor<[4,10,10,4],f32>, %arg1: !torch.vtensor<[4,6,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %true = torch.constant.bool 0 %int0 = torch.constant.int 0 @@ -35,22 +40,25 @@ func.func @grid_sampler(%arg0: !torch.vtensor<[4,10,10,4],f32>, %arg1: !torch.vt // CHECK-LABEL: func @grid_sampler2 // CHECK: #map -// CHECK-DAG: %[[X15:.*]] = arith.mulf %[[X13:.*]], %[[X8:.*]] : f32 -// CHECK-DAG: %[[X16:.*]] = arith.mulf %[[X14:.*]], %[[X9:.*]] : f32 -// CHECK-DAG: %[[X40:.*]] = arith.mulf %[[EXTRACTED:.*]], %[[X39:.*]] : f32 -// CHECK-DAG: %[[X41:.*]] = arith.mulf %[[X31:.*]], %[[X37:.*]] : f32 -// CHECK-DAG: %[[X42:.*]] = arith.addf %[[X40:.*]], %[[X41]] : f32 -// CHECK-DAG: %[[X43:.*]] = arith.subf %[[CST_1:.*]], %[[X37]] : f32 -// CHECK-DAG: %[[X45:.*]] = arith.mulf %[[X34:.*]], %[[X37]] : f32 -// CHECK-DAG: %[[X46:.*]] = arith.addf %[[X44:.*]], %[[X45]] : f32 -// CHECK-DAG: %[[X47:.*]] = arith.subf %[[CST_1]], %[[X38:.*]] : f32 -// CHECK-DAG: %[[X48:.*]] = arith.mulf %[[X42]], %[[XX47:.*]] : f32 -// CHECK-DAG: %[[X49:.*]] = arith.mulf %[[X46]], %[[XX38:.*]] : f32 -// CHECK-DAG: %[[X50:.*]] = arith.addf %[[X48]], %[[X49]] : f32 -// CHECK-DAG: linalg.yield %[[X50]] : f32 -// CHECK: } -> tensor -// CHECK: %[[X12:.*]] = torch_c.from_builtin_tensor %[[X11:.*]] : tensor -> !torch.vtensor<[?,?,?,?],f32> -// CHECK: return %[[X12]] : !torch.vtensor<[?,?,?,?],f32> +// CHECK-DAG: %[[X70:.*]] = arith.addf %[[X68:.*]], %[[X69:.*]] : f32 +// CHECK-DAG: %[[X29:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: %[[X71:.*]] = arith.cmpf olt, %[[X58:.*]], %[[X29:.*]] : f32 +// CHECK-DAG: %[[X72:.*]] = arith.select %[[X71:.*]], %[[X52:.*]], %[[X54:.*]] : f32 +// CHECK-DAG: %[[X30:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[X73:.*]] = arith.cmpi eq, %[[X3:.*]], %[[X30:.*]] : i64 +// CHECK-DAG: %[[X74:.*]] = arith.select %[[X73:.*]], %[[X70:.*]], %[[X72:.*]] : f32 +// CHECK-DAG: %[[X75:.*]] = arith.subf %[[X1:.*]], %[[X57:.*]] : f32 +// CHECK-DAG: %[[X76:.*]] = arith.mulf %[[X66:.*]], %[[X75:.*]] : f32 +// CHECK-DAG: %[[X77:.*]] = arith.mulf %[[X74:.*]], %[[X57:.*]] : f32 +// CHECK-DAG: %[[X78:.*]] = arith.addf %[[X76:.*]], %[[X77:.*]] : f32 +// CHECK-DAG: %[[X31:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: %[[X79:.*]] = arith.cmpf olt, %[[X57:.*]], %[[X31:.*]] : f32 +// CHECK-DAG: %[[X80:.*]] = arith.select %[[X79:.*]], %[[X66:.*]], %[[X74:.*]] : f32 +// CHECK-DAG: %[[X32:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[X81:.*]] = arith.cmpi eq, %[[X3:.*]], %[[X32:.*]] : i64 +// CHECK-DAG: %[[X82:.*]] = arith.select %[[X81:.*]], %[[X78:.*]], %[[X80:.*]] : f32 +// CHECK-DAG: linalg.yield %[[X50:.*]] : f32 +// CHECK: return %[[X12:.*]] : !torch.vtensor<[?,?,?,?],f32> func.func @grid_sampler2(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %true = torch.constant.bool 0 %int0 = torch.constant.int 0 @@ -64,21 +72,21 @@ func.func @grid_sampler2(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vte // CHECK-LABEL: func @grid_sampler3 // CHECK: #map // CHECK-DAG: %[[X15:.*]] = arith.mulf %[[X13:.*]], %[[X8:.*]] : f32 -// CHECK-DAG: %[[X16:.*]] = arith.mulf %[[X14:.*]], %[[X9:.*]] : f32 -// CHECK-DAG: %[[X40:.*]] = arith.mulf %[[EXTRACTED:.*]], %[[X39:.*]] : f32 -// CHECK-DAG: %[[X41:.*]] = arith.mulf %[[X31:.*]], %[[X37:.*]] : f32 -// CHECK-DAG: %[[X42:.*]] = arith.addf %[[X40:.*]], %[[X41]] : f32 -// CHECK-DAG: %[[X43:.*]] = arith.subf %[[CST_1:.*]], %[[X37]] : f32 -// CHECK-DAG: %[[X45:.*]] = arith.mulf %[[X34:.*]], %[[X37]] : f32 -// CHECK-DAG: %[[X46:.*]] = arith.addf %[[X44:.*]], %[[X45]] : f32 -// CHECK-DAG: %[[X47:.*]] = arith.subf %[[CST_1]], %[[X38:.*]] : f32 -// CHECK-DAG: %[[X48:.*]] = arith.mulf %[[X42]], %[[XX47:.*]] : f32 -// CHECK-DAG: %[[X49:.*]] = arith.mulf %[[X46]], %[[XX38:.*]] : f32 -// CHECK-DAG: %[[X50:.*]] = arith.addf %[[X48]], %[[X49]] : f32 -// CHECK-DAG: linalg.yield %[[X50]] : f32 -// CHECK: } -> tensor -// CHECK: %[[X12:.*]] = torch_c.from_builtin_tensor %[[X11:.*]] : tensor -> !torch.vtensor<[?,?,?,?],f32> -// CHECK: return %[[X12]] : !torch.vtensor<[?,?,?,?],f32> +// CHECK-DAG: %[[Y60:.*]] = arith.mulf %[[X48:.*]], %[[X59:.*]] : f32 +// CHECK-DAG: %[[Y61:.*]] = arith.mulf %[[X50:.*]], %[[X58:.*]] : f32 +// CHECK-DAG: %[[Y62:.*]] = arith.addf %[[X60:.*]], %[[X61:.*]] : f32 +// CHECK-DAG: %[[Y28:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: %[[Y64:.*]] = arith.select %[[X63:.*]], %[[X48:.*]], %[[X50:.*]] : f32 +// CHECK-DAG: %[[Y29:.*]] = arith.constant 0 : i6 +// CHECK-DAG: %[[Y65:.*]] = arith.cmpi eq, %[[X3:.*]], %[[X28:.*]] : i64 +// CHECK-DAG: %[[Y66:.*]] = arith.select %[[X65:.*]], %[[X62:.*]], %[[X64:.*]] : f32 +// CHECK-DAG: %[[Y67:.*]] = arith.subf %[[X1:.*]], %[[X58:.*]] : f32 +// CHECK-DAG: %[[Y68:.*]] = arith.mulf %[[X52:.*]], %[[X67:.*]] : f32 +// CHECK-DAG: %[[Y69:.*]] = arith.mulf %[[X54:.*]], %[[X58:.*]] : f32 +// CHECK-DAG: %[[Y70:.*]] = arith.addf %[[X68:.*]], %[[X69:.*]] : f32 +// CHECK-DAG: %[[Y30:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: %[[Y31:.*]] = arith.constant 0 : i64 +// CHECK: return %[[X12:.*]] : !torch.vtensor<[?,?,?,?],f32> func.func @grid_sampler3(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %false = torch.constant.bool 1 %int0 = torch.constant.int 0 @@ -86,3 +94,14 @@ func.func @grid_sampler3(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vte %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> return %4 : !torch.vtensor<[?,?,?,?],f32> } + +// ----- + +// CHECK-LABEL: func @grid_sampler4 +func.func @grid_sampler4(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %false = torch.constant.bool 1 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +}