diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index 3996ad35d6..f865c70815 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -3531,6 +3531,57 @@ Effects: `MemoryEffects::Effect{}` _ONNX GridSample operation_ +Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`. +For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2), +the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W), +the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out). +More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr), +the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out). + +The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in). +The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values +at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode) +and a padding mode (for `grid` positions falling outside the 2-dimensional image). + +For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`. +They are used to interpolate output values of `Y[n, c, h_out, w_out]`. + +The GridSample operator is often used in doing grid generator and sampler in the +[Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). +See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). + +Traits: `AlwaysSpeculatableImplTrait` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + + + +
AttributeMLIR TypeDescription
align_corners::mlir::IntegerAttr64-bit signed integer attribute
mode::mlir::StringAttrstring attribute
padding_mode::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values +| `grid` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Y` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values + +### `onnx.GridSampleV16` (ONNXGridSampleV16Op) + +_ONNX GridSample operation_ + Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from `grid`. Currently, only spatial (4-D) inputs are supported. For input `X` with shape (N, C, H, W) and `grid` with shape (N, H_out, W_out, 2), the output `Y` will have shape (N, C, H_out, W_out). diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index 067839b22f..0bb9ce1190 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -79,7 +79,7 @@ op_dialect_version_map_["GlobalMaxPool"] = {1}; op_dialect_version_map_["Gradient"] = {1}; op_dialect_version_map_["Greater"] = {13}; op_dialect_version_map_["GreaterOrEqual"] = {16}; -op_dialect_version_map_["GridSample"] = {16}; +op_dialect_version_map_["GridSample"] = {22, 16}; op_dialect_version_map_["GroupNormalization"] = {21, 18}; op_dialect_version_map_["HammingWindow"] = {17}; op_dialect_version_map_["HannWindow"] = {17}; @@ -356,6 +356,8 @@ import_handler_map_["GreaterOrEqual"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["GridSample"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GridSampleV16"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["GroupNormalization"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["GroupNormalizationV18"] = diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt index 3917c94aa4..b68e1cf247 100644 --- a/src/Dialect/ONNX/CMakeLists.txt +++ b/src/Dialect/ONNX/CMakeLists.txt @@ -80,6 +80,7 @@ add_onnx_mlir_library(OMONNXOps ONNXOps/Tensor/Gather.cpp ONNXOps/Tensor/GatherElements.cpp ONNXOps/Tensor/GatherND.cpp + ONNXOps/Tensor/GridSample.cpp ONNXOps/Tensor/Identity.cpp ONNXOps/Tensor/NonZero.cpp ONNXOps/Tensor/OneHot.cpp diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 685c5438be..021dc53194 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -3062,6 +3062,57 @@ def ONNXGreaterOrEqualOp:ONNX_Op<"GreaterOrEqual", } def ONNXGridSampleOp:ONNX_Op<"GridSample", + [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "ONNX GridSample operation"; + let description = [{ + Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`. + For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2), + the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W), + the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out). + More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr), + the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out). + + The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in). + The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values + at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode) + and a padding mode (for `grid` positions falling outside the 2-dimensional image). + + For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`. + They are used to interpolate output values of `Y[n, c, h_out, w_out]`. + + The GridSample operator is often used in doing grid generator and sampler in the + [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). + See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$grid, + DefaultValuedAttr:$align_corners, + DefaultValuedStrAttr:$mode, + DefaultValuedStrAttr:$padding_mode); + let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {30}; + } + }]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; + let hasVerifier = 1; +} + +def ONNXGridSampleV16Op:ONNX_Op<"GridSampleV16", [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GridSample operation"; let description = [{ @@ -3099,7 +3150,7 @@ def ONNXGridSampleOp:ONNX_Op<"GridSample", let extraClassDefinition = [{ onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { - onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleOpShapeHelper(op, oper, ieb, scope); + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleV16OpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); return sh; } diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp index 01a8943ead..80be997fda 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp @@ -868,6 +868,7 @@ using ONNXDimOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXDropoutOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXDynamicQuantizeLinearOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXEinsumOpShapeHelper = ONNXNonSpecificOpShapeHelper; +using ONNXGridSampleOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXEyeLikeOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXFlattenOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXGatherElementsOpShapeHelper = ONNXNonSpecificOpShapeHelper; diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp new file mode 100644 index 0000000000..c281ca296e --- /dev/null +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp @@ -0,0 +1,128 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------------ GridSample.cpp - ONNX Operations ------------------===// +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file provides definition of ONNX dialect GridSample operation. +// +//===----------------------------------------------------------------------===// + +#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" + +using namespace mlir; +using namespace mlir::OpTrait::util; +using namespace onnx_mlir; + +//===----------------------------------------------------------------------===// +// Support +//===----------------------------------------------------------------------===// + +namespace onnx_mlir { + +template <> +LogicalResult ONNXGridSampleOpShapeHelper::computeShape() { + + // Read data and indices shapes as dim indices. + ONNXGridSampleOpAdaptor operandAdaptor(operands); + DimsExpr inputDims; + DimsExpr gridDims; + createIE->getShapeAsDims(operandAdaptor.getX(), inputDims); + createIE->getShapeAsDims(operandAdaptor.getGrid(), gridDims); + + int64_t gridRank = gridDims.size(); + + // Input's dimensions of rank r+2 should be in the form of (N,C,D1,D2,...,Dr) + // Grid's dimensions should also have rank r+2 and be in the form of + // (N,D1_out,D2_out,...,Dr_out,r). + // The output Y will have shape (N, C, D1_out, D2_out, ..., Dr_out). + DimsExpr outputDims; + outputDims.emplace_back(inputDims[0]); + outputDims.emplace_back(inputDims[1]); + for (int i = 1; i < gridRank - 1; ++i) { + outputDims.emplace_back(gridDims[i]); + } + + setOutputDims(outputDims); + return success(); +} + +} // namespace onnx_mlir + +//===----------------------------------------------------------------------===// +// Verify +//===----------------------------------------------------------------------===// + +LogicalResult ONNXGridSampleOp::verify() { + ONNXGridSampleOpAdaptor operandAdaptor(*this); + auto op = mlir::cast(*this); + + const auto alignCorners = op.getAlignCorners(); + if (alignCorners != 0 && alignCorners != 1) { + return emitOpError("align_corners needs to be 0 or 1"); + } + const auto mode = op.getMode(); + if (mode != "linear" && mode != "nearest" && mode != "cubic") { + return emitOpError("mode needs to be linear, nearest or cubic"); + } + const auto paddingMode = op.getPaddingMode(); + if (paddingMode != "zeros" && paddingMode != "border" && + paddingMode != "reflection") { + return emitOpError("padding_mode needs to be zeros, border or reflection"); + } + + if (!hasShapeAndRank(getOperation())) + return success(); + + auto inputShape = + mlir::cast(operandAdaptor.getX().getType()).getShape(); + int64_t inputRank = inputShape.size(); + auto gridShape = + mlir::cast(operandAdaptor.getGrid().getType()).getShape(); + + // Check whether the ranks of input and grid are valid and are equal. + // Input's dimensions of rank r+2 should be in the form of (N,C,D1,D2,...,Dr) + // Grid's dimensions should also have rank r+2 and be in the form of + // (N,D1_out,D2_out,...,Dr_out,r). + if (inputShape.size() != gridShape.size()) { + return emitOpError() << "Input(=" << inputShape.size() + << ") and grid(=" << gridShape.size() + << ") have different dim sizes."; + } + + if (inputShape[0] != gridShape[0]) { + return emitOpError() << "Input and grid must have the same batch value."; + } + + if (!ShapedType::isDynamic(gridShape.back()) && + gridShape.back() != inputRank - 2) { + return emitOpError() << "Grid last dim must have been '" << inputRank - 2 + << "' instead of '" << gridShape.back() << "'."; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Shape Inference +//===----------------------------------------------------------------------===// + +LogicalResult ONNXGridSampleOp::inferShapes( + std::function /*doShapeInference*/) { + + Type elementType = mlir::cast(getX().getType()).getElementType(); + ONNXGridSampleOpShapeHelper shapeHelper(getOperation(), {}); + return shapeHelper.computeShapeAndUpdateType(elementType); +} + +//===----------------------------------------------------------------------===// +// Template instantiation +//===----------------------------------------------------------------------===// + +namespace onnx_mlir { +template struct ONNXNonSpecificOpShapeHelper; +} // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp index 8a43b3e4a1..c32d07c672 100644 --- a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp +++ b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp @@ -37,7 +37,6 @@ UNSUPPORTED_OPS(ONNXDeformConvOp) UNSUPPORTED_OPS(ONNXDictVectorizerOp) UNSUPPORTED_OPS(ONNXFeatureVectorizerOp) UNSUPPORTED_OPS(ONNXGradientOp) -UNSUPPORTED_OPS(ONNXGridSampleOp) UNSUPPORTED_OPS(ONNXHammingWindowOp) UNSUPPORTED_OPS(ONNXHannWindowOp) UNSUPPORTED_OPS(ONNXImputerOp) @@ -76,6 +75,7 @@ CONVERTED_TO_SUPPORTED_OPS(ONNXClipV11Op) CONVERTED_TO_SUPPORTED_OPS(ONNXClipV12Op) CONVERTED_TO_SUPPORTED_OPS(ONNXClipV6Op) CONVERTED_TO_SUPPORTED_OPS(ONNXDFTV17Op) +CONVERTED_TO_SUPPORTED_OPS(ONNXGridSampleV16Op) CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationOp) CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationV18Op) CONVERTED_TO_SUPPORTED_OPS(ONNXPadV18Op) diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp index ed9e0fb150..d61a980e15 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.cpp +++ b/src/Dialect/ONNX/Transforms/Decompose.cpp @@ -387,6 +387,18 @@ bool canSequenceAtBeReplaced(Value sequenceAtResult) { return true; } +Attribute upgradeGridSampleV16Mode(PatternRewriter &rewriter, Attribute mode) { + const auto stringMode = mlir::cast(mode); + if (stringMode.strref() == "bilinear") { + return rewriter.getStringAttr("linear"); + } + if (stringMode.strref() == "bicubic") { + return rewriter.getStringAttr("cubic"); + } + assert(stringMode.strref() == "nearest"); + return mode; +} + Value replaceSequenceAt( PatternRewriter &rewriter, Location loc, Value sequenceAtResult) { ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp(); @@ -1318,6 +1330,7 @@ void DecomposeONNXToONNXPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/src/Dialect/ONNX/Transforms/Decompose.td b/src/Dialect/ONNX/Transforms/Decompose.td index 00ae9f6ff3..bc7044b524 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.td +++ b/src/Dialect/ONNX/Transforms/Decompose.td @@ -73,6 +73,9 @@ def ReshapeElementsAttrToRank0 : NativeCodeCall< def ReplaceSequenceAt : NativeCodeCall< "onnx_mlir::replaceSequenceAt($_builder, $_loc, $0)">; + +def UpgradeGridSampleV16Mode : NativeCodeCall< + "onnx_mlir::upgradeGridSampleV16Mode($_builder, $0)">; def CanSequenceAtBeReplaced : Constraint, "check whether the SequenceAt can be replaced with split">; @@ -365,6 +368,12 @@ def ClipV12Pattern : Pat< (ONNXClipOp $x, $min, $max) >; +// Rewrite GridSample 16 to GridSample 22 +def GridSampleV16Pattern : Pat< + (ONNXGridSampleV16Op $x, $grid, $align_corners, $mode, $padding_mode), + (ONNXGridSampleOp $x, $grid, $align_corners, (UpgradeGridSampleV16Mode $mode), $padding_mode) +>; + def DFTV17Pattern : Pat< (ONNXDFTV17Op $x, $dft_length, $axis, $inverse, $onesided), (ONNXDFTOp $x, $dft_length, (ONNXConstantOpFromDenseAttr(createScalarDenseAttrRank0 $axis)), $inverse, $onesided) diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir index 7b8d087c03..763de9043a 100644 --- a/test/mlir/onnx/invalid.mlir +++ b/test/mlir/onnx/invalid.mlir @@ -719,3 +719,51 @@ func.func @test_matmulinteger_wrong_B_broadcast(%arg0: tensor<16x32xui8>, %arg1: %0 = "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (tensor<16x32xui8>, tensor<5x32x64xui8>, tensor<16xui8>, tensor<5x1x2xui8>) -> tensor<5x16x64xi32> onnx.Return %0 : tensor<5x16x64xi32> } + +// ----- + +func.func @test_grid_sample_diff_ranks(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op Input(=4) and grid(=3) have different dim sizes.}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_diff_batch(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op Input and grid must have the same batch value.}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_align_corners(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op align_corners needs to be 0 or 1}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 2 : si64, mode = "linear", padding_mode = "border"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_mode(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op mode needs to be linear, nearest or cubic}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "sampling", padding_mode = "border"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_padding(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op padding_mode needs to be zeros, border or reflection}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "cubic", padding_mode = "bottom"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_wrong_dim_grid(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x3xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op Grid last dim must have been '2' instead of '3'.}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x3xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} \ No newline at end of file diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir index 0f8ac9f554..5a4bfc02a0 100644 --- a/test/mlir/onnx/onnx_decompose.mlir +++ b/test/mlir/onnx/onnx_decompose.mlir @@ -2,6 +2,42 @@ // ----- +func.func @test_grid_sample_v16_bicubic(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bicubic", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_grid_sample_v16_bicubic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "cubic", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_grid_sample_v16_bilinear(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_grid_sample_v16_bilinear +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "linear", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_grid_sample_v16_nearest(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "nearest", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_grid_sample_v16_nearest +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "nearest", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + func.func @test_dft(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { %cst = "onnx.NoValue"() {value} : () -> none %0 ="onnx.DFTV17"(%arg0, %arg1) : (tensor, tensor)-> tensor<*xf32> diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 4a625c9405..478ea172b5 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -3819,3 +3819,75 @@ func.func @test_RMSlayer_norm_2inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5 // CHECK: } } +// ----- + +// Test Grid Sample + +func.func @test_grid_sample_same_dims(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_same_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x1152x1344xf32>, [[PARAM_1_:%.+]]: tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> +// CHECK: return [[GRID]] : tensor<1x3x1152x1344xf32> +// CHECK: } +} + +func.func @test_grid_sample_diff_dims(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_diff_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> +// CHECK: return [[GRID]] : tensor<1x1x6x6xf32> +// CHECK: } +} + +func.func @test_grid_sample_6d(%arg0: tensor<1x2x4x4x4x4xf32>, %arg1: tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_6d +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x4x4x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> +// CHECK: return [[GRID]] : tensor<1x2x6x6x4x4xf32> +// CHECK: } +} + +func.func @test_grid_sample_dim_shape(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_dim_shape +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: return [[GRID]] : tensor +// CHECK: } + return %0 : tensor<*xf32> +} + +func.func @test_grid_sample_dim_shape2(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_dim_shape2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: return [[GRID]] : tensor +// CHECK: } + return %0 : tensor<*xf32> +} + +func.func @test_grid_sample_dim_shape3(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_dim_shape3 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: return [[GRID]] : tensor +// CHECK: } + return %0 : tensor<*xf32> +} \ No newline at end of file diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index a32c931521..e1473ad92d 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -160,7 +160,7 @@ "Gradient": [1], "Greater": [13], "GreaterOrEqual": [16], - "GridSample": [16], + "GridSample": [22, 16], "GroupNormalization": [21, 18], "HammingWindow": [17], "HannWindow": [17], @@ -396,6 +396,7 @@ "Gelu", "Greater", "GreaterOrEqual", + "GridSample", "GroupNormalizationV18", "Hardmax", "If",