diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md
index 3996ad35d6..f865c70815 100644
--- a/docs/Dialects/onnx.md
+++ b/docs/Dialects/onnx.md
@@ -3531,6 +3531,57 @@ Effects: `MemoryEffects::Effect{}`
_ONNX GridSample operation_
+Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`.
+For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2),
+the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W),
+the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out).
+More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr),
+the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out).
+
+The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in).
+The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values
+at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode)
+and a padding mode (for `grid` positions falling outside the 2-dimensional image).
+
+For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`.
+They are used to interpolate output values of `Y[n, c, h_out, w_out]`.
+
+The GridSample operator is often used in doing grid generator and sampler in the
+[Spatial Transformer Networks](https://arxiv.org/abs/1506.02025).
+See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).
+
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+align_corners | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+mode | ::mlir::StringAttr | string attribute |
+padding_mode | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values
+| `grid` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `Y` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values
+
+### `onnx.GridSampleV16` (ONNXGridSampleV16Op)
+
+_ONNX GridSample operation_
+
Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from `grid`.
Currently, only spatial (4-D) inputs are supported. For input `X` with shape (N, C, H, W) and `grid` with shape (N, H_out, W_out, 2),
the output `Y` will have shape (N, C, H_out, W_out).
diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc
index 067839b22f..0bb9ce1190 100644
--- a/src/Builder/OpBuildTable.inc
+++ b/src/Builder/OpBuildTable.inc
@@ -79,7 +79,7 @@ op_dialect_version_map_["GlobalMaxPool"] = {1};
op_dialect_version_map_["Gradient"] = {1};
op_dialect_version_map_["Greater"] = {13};
op_dialect_version_map_["GreaterOrEqual"] = {16};
-op_dialect_version_map_["GridSample"] = {16};
+op_dialect_version_map_["GridSample"] = {22, 16};
op_dialect_version_map_["GroupNormalization"] = {21, 18};
op_dialect_version_map_["HammingWindow"] = {17};
op_dialect_version_map_["HannWindow"] = {17};
@@ -356,6 +356,8 @@ import_handler_map_["GreaterOrEqual"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation;
import_handler_map_["GridSample"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation;
+import_handler_map_["GridSampleV16"] =
+ &onnx_mlir::detail::FrontendGenImpl::buildOperation;
import_handler_map_["GroupNormalization"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation;
import_handler_map_["GroupNormalizationV18"] =
diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt
index 3917c94aa4..b68e1cf247 100644
--- a/src/Dialect/ONNX/CMakeLists.txt
+++ b/src/Dialect/ONNX/CMakeLists.txt
@@ -80,6 +80,7 @@ add_onnx_mlir_library(OMONNXOps
ONNXOps/Tensor/Gather.cpp
ONNXOps/Tensor/GatherElements.cpp
ONNXOps/Tensor/GatherND.cpp
+ ONNXOps/Tensor/GridSample.cpp
ONNXOps/Tensor/Identity.cpp
ONNXOps/Tensor/NonZero.cpp
ONNXOps/Tensor/OneHot.cpp
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index 685c5438be..021dc53194 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -3062,6 +3062,57 @@ def ONNXGreaterOrEqualOp:ONNX_Op<"GreaterOrEqual",
}
def ONNXGridSampleOp:ONNX_Op<"GridSample",
+ [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> {
+ let summary = "ONNX GridSample operation";
+ let description = [{
+ Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`.
+ For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2),
+ the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W),
+ the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out).
+ More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr),
+ the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out).
+
+ The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in).
+ The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values
+ at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode)
+ and a padding mode (for `grid` positions falling outside the 2-dimensional image).
+
+ For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`.
+ They are used to interpolate output values of `Y[n, c, h_out, w_out]`.
+
+ The GridSample operator is often used in doing grid generator and sampler in the
+ [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025).
+ See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).
+ }];
+ let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X,
+ AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$grid,
+ DefaultValuedAttr:$align_corners,
+ DefaultValuedStrAttr:$mode,
+ DefaultValuedStrAttr:$padding_mode);
+ let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$Y);
+ let extraClassDeclaration = [{
+ static int getNumberOfOperands() {
+ return 2;
+ }
+ static int getNumberOfResults() {
+ return 1;
+ }
+ static std::vector getTypeMap() {
+ return {30};
+ }
+ }];
+ let extraClassDefinition = [{
+ onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper,
+ onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
+ onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleOpShapeHelper(op, oper, ieb, scope);
+ assert(sh && "failed to allocate shape helper");
+ return sh;
+ }
+ }];
+ let hasVerifier = 1;
+}
+
+def ONNXGridSampleV16Op:ONNX_Op<"GridSampleV16",
[Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> {
let summary = "ONNX GridSample operation";
let description = [{
@@ -3099,7 +3150,7 @@ def ONNXGridSampleOp:ONNX_Op<"GridSample",
let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
- onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleOpShapeHelper(op, oper, ieb, scope);
+ onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleV16OpShapeHelper(op, oper, ieb, scope);
assert(sh && "failed to allocate shape helper");
return sh;
}
diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
index 01a8943ead..80be997fda 100644
--- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
+++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
@@ -868,6 +868,7 @@ using ONNXDimOpShapeHelper = ONNXNonSpecificOpShapeHelper;
using ONNXDropoutOpShapeHelper = ONNXNonSpecificOpShapeHelper;
using ONNXDynamicQuantizeLinearOpShapeHelper = ONNXNonSpecificOpShapeHelper;
using ONNXEinsumOpShapeHelper = ONNXNonSpecificOpShapeHelper;
+using ONNXGridSampleOpShapeHelper = ONNXNonSpecificOpShapeHelper;
using ONNXEyeLikeOpShapeHelper = ONNXNonSpecificOpShapeHelper;
using ONNXFlattenOpShapeHelper = ONNXNonSpecificOpShapeHelper;
using ONNXGatherElementsOpShapeHelper = ONNXNonSpecificOpShapeHelper;
diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp
new file mode 100644
index 0000000000..c281ca296e
--- /dev/null
+++ b/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp
@@ -0,0 +1,128 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===------------------ GridSample.cpp - ONNX Operations ------------------===//
+//
+// Copyright (c) 2024 Advanced Micro Devices, Inc.
+//
+// =============================================================================
+//
+// This file provides definition of ONNX dialect GridSample operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
+
+using namespace mlir;
+using namespace mlir::OpTrait::util;
+using namespace onnx_mlir;
+
+//===----------------------------------------------------------------------===//
+// Support
+//===----------------------------------------------------------------------===//
+
+namespace onnx_mlir {
+
+template <>
+LogicalResult ONNXGridSampleOpShapeHelper::computeShape() {
+
+ // Read data and indices shapes as dim indices.
+ ONNXGridSampleOpAdaptor operandAdaptor(operands);
+ DimsExpr inputDims;
+ DimsExpr gridDims;
+ createIE->getShapeAsDims(operandAdaptor.getX(), inputDims);
+ createIE->getShapeAsDims(operandAdaptor.getGrid(), gridDims);
+
+ int64_t gridRank = gridDims.size();
+
+ // Input's dimensions of rank r+2 should be in the form of (N,C,D1,D2,...,Dr)
+ // Grid's dimensions should also have rank r+2 and be in the form of
+ // (N,D1_out,D2_out,...,Dr_out,r).
+ // The output Y will have shape (N, C, D1_out, D2_out, ..., Dr_out).
+ DimsExpr outputDims;
+ outputDims.emplace_back(inputDims[0]);
+ outputDims.emplace_back(inputDims[1]);
+ for (int i = 1; i < gridRank - 1; ++i) {
+ outputDims.emplace_back(gridDims[i]);
+ }
+
+ setOutputDims(outputDims);
+ return success();
+}
+
+} // namespace onnx_mlir
+
+//===----------------------------------------------------------------------===//
+// Verify
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXGridSampleOp::verify() {
+ ONNXGridSampleOpAdaptor operandAdaptor(*this);
+ auto op = mlir::cast(*this);
+
+ const auto alignCorners = op.getAlignCorners();
+ if (alignCorners != 0 && alignCorners != 1) {
+ return emitOpError("align_corners needs to be 0 or 1");
+ }
+ const auto mode = op.getMode();
+ if (mode != "linear" && mode != "nearest" && mode != "cubic") {
+ return emitOpError("mode needs to be linear, nearest or cubic");
+ }
+ const auto paddingMode = op.getPaddingMode();
+ if (paddingMode != "zeros" && paddingMode != "border" &&
+ paddingMode != "reflection") {
+ return emitOpError("padding_mode needs to be zeros, border or reflection");
+ }
+
+ if (!hasShapeAndRank(getOperation()))
+ return success();
+
+ auto inputShape =
+ mlir::cast(operandAdaptor.getX().getType()).getShape();
+ int64_t inputRank = inputShape.size();
+ auto gridShape =
+ mlir::cast(operandAdaptor.getGrid().getType()).getShape();
+
+ // Check whether the ranks of input and grid are valid and are equal.
+ // Input's dimensions of rank r+2 should be in the form of (N,C,D1,D2,...,Dr)
+ // Grid's dimensions should also have rank r+2 and be in the form of
+ // (N,D1_out,D2_out,...,Dr_out,r).
+ if (inputShape.size() != gridShape.size()) {
+ return emitOpError() << "Input(=" << inputShape.size()
+ << ") and grid(=" << gridShape.size()
+ << ") have different dim sizes.";
+ }
+
+ if (inputShape[0] != gridShape[0]) {
+ return emitOpError() << "Input and grid must have the same batch value.";
+ }
+
+ if (!ShapedType::isDynamic(gridShape.back()) &&
+ gridShape.back() != inputRank - 2) {
+ return emitOpError() << "Grid last dim must have been '" << inputRank - 2
+ << "' instead of '" << gridShape.back() << "'.";
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Shape Inference
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXGridSampleOp::inferShapes(
+ std::function /*doShapeInference*/) {
+
+ Type elementType = mlir::cast(getX().getType()).getElementType();
+ ONNXGridSampleOpShapeHelper shapeHelper(getOperation(), {});
+ return shapeHelper.computeShapeAndUpdateType(elementType);
+}
+
+//===----------------------------------------------------------------------===//
+// Template instantiation
+//===----------------------------------------------------------------------===//
+
+namespace onnx_mlir {
+template struct ONNXNonSpecificOpShapeHelper;
+} // namespace onnx_mlir
diff --git a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp
index 8a43b3e4a1..c32d07c672 100644
--- a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp
+++ b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp
@@ -37,7 +37,6 @@ UNSUPPORTED_OPS(ONNXDeformConvOp)
UNSUPPORTED_OPS(ONNXDictVectorizerOp)
UNSUPPORTED_OPS(ONNXFeatureVectorizerOp)
UNSUPPORTED_OPS(ONNXGradientOp)
-UNSUPPORTED_OPS(ONNXGridSampleOp)
UNSUPPORTED_OPS(ONNXHammingWindowOp)
UNSUPPORTED_OPS(ONNXHannWindowOp)
UNSUPPORTED_OPS(ONNXImputerOp)
@@ -76,6 +75,7 @@ CONVERTED_TO_SUPPORTED_OPS(ONNXClipV11Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXClipV12Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXClipV6Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXDFTV17Op)
+CONVERTED_TO_SUPPORTED_OPS(ONNXGridSampleV16Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationOp)
CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationV18Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXPadV18Op)
diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp
index ed9e0fb150..d61a980e15 100644
--- a/src/Dialect/ONNX/Transforms/Decompose.cpp
+++ b/src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -387,6 +387,18 @@ bool canSequenceAtBeReplaced(Value sequenceAtResult) {
return true;
}
+Attribute upgradeGridSampleV16Mode(PatternRewriter &rewriter, Attribute mode) {
+ const auto stringMode = mlir::cast(mode);
+ if (stringMode.strref() == "bilinear") {
+ return rewriter.getStringAttr("linear");
+ }
+ if (stringMode.strref() == "bicubic") {
+ return rewriter.getStringAttr("cubic");
+ }
+ assert(stringMode.strref() == "nearest");
+ return mode;
+}
+
Value replaceSequenceAt(
PatternRewriter &rewriter, Location loc, Value sequenceAtResult) {
ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp();
@@ -1318,6 +1330,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
target.addIllegalOp();
target.addIllegalOp();
target.addIllegalOp();
+ target.addIllegalOp();
target.addIllegalOp();
target.addIllegalOp();
target.addIllegalOp();
diff --git a/src/Dialect/ONNX/Transforms/Decompose.td b/src/Dialect/ONNX/Transforms/Decompose.td
index 00ae9f6ff3..bc7044b524 100644
--- a/src/Dialect/ONNX/Transforms/Decompose.td
+++ b/src/Dialect/ONNX/Transforms/Decompose.td
@@ -73,6 +73,9 @@ def ReshapeElementsAttrToRank0 : NativeCodeCall<
def ReplaceSequenceAt : NativeCodeCall<
"onnx_mlir::replaceSequenceAt($_builder, $_loc, $0)">;
+
+def UpgradeGridSampleV16Mode : NativeCodeCall<
+ "onnx_mlir::upgradeGridSampleV16Mode($_builder, $0)">;
def CanSequenceAtBeReplaced :
Constraint, "check whether the SequenceAt can be replaced with split">;
@@ -365,6 +368,12 @@ def ClipV12Pattern : Pat<
(ONNXClipOp $x, $min, $max)
>;
+// Rewrite GridSample 16 to GridSample 22
+def GridSampleV16Pattern : Pat<
+ (ONNXGridSampleV16Op $x, $grid, $align_corners, $mode, $padding_mode),
+ (ONNXGridSampleOp $x, $grid, $align_corners, (UpgradeGridSampleV16Mode $mode), $padding_mode)
+>;
+
def DFTV17Pattern : Pat<
(ONNXDFTV17Op $x, $dft_length, $axis, $inverse, $onesided),
(ONNXDFTOp $x, $dft_length, (ONNXConstantOpFromDenseAttr(createScalarDenseAttrRank0 $axis)), $inverse, $onesided)
diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir
index 7b8d087c03..763de9043a 100644
--- a/test/mlir/onnx/invalid.mlir
+++ b/test/mlir/onnx/invalid.mlir
@@ -719,3 +719,51 @@ func.func @test_matmulinteger_wrong_B_broadcast(%arg0: tensor<16x32xui8>, %arg1:
%0 = "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (tensor<16x32xui8>, tensor<5x32x64xui8>, tensor<16xui8>, tensor<5x1x2xui8>) -> tensor<5x16x64xi32>
onnx.Return %0 : tensor<5x16x64xi32>
}
+
+// -----
+
+func.func @test_grid_sample_diff_ranks(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x2xf32>) -> tensor<*xf32> {
+ // expected-error @+1 {{'onnx.GridSample' op Input(=4) and grid(=3) have different dim sizes.}}
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_grid_sample_diff_batch(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> {
+ // expected-error @+1 {{'onnx.GridSample' op Input and grid must have the same batch value.}}
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_grid_sample_align_corners(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> {
+ // expected-error @+1 {{'onnx.GridSample' op align_corners needs to be 0 or 1}}
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 2 : si64, mode = "linear", padding_mode = "border"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_grid_sample_mode(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> {
+ // expected-error @+1 {{'onnx.GridSample' op mode needs to be linear, nearest or cubic}}
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "sampling", padding_mode = "border"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_grid_sample_padding(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> {
+ // expected-error @+1 {{'onnx.GridSample' op padding_mode needs to be zeros, border or reflection}}
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "cubic", padding_mode = "bottom"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_grid_sample_wrong_dim_grid(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x3xf32>) -> tensor<*xf32> {
+ // expected-error @+1 {{'onnx.GridSample' op Grid last dim must have been '2' instead of '3'.}}
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x3xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
\ No newline at end of file
diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir
index 0f8ac9f554..5a4bfc02a0 100644
--- a/test/mlir/onnx/onnx_decompose.mlir
+++ b/test/mlir/onnx/onnx_decompose.mlir
@@ -2,6 +2,42 @@
// -----
+func.func @test_grid_sample_v16_bicubic(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> {
+ %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bicubic", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+// CHECK-LABEL: func.func @test_grid_sample_v16_bicubic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "cubic", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32>
+// CHECK: return [[VAR_0_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_grid_sample_v16_bilinear(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> {
+ %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+// CHECK-LABEL: func.func @test_grid_sample_v16_bilinear
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "linear", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32>
+// CHECK: return [[VAR_0_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_grid_sample_v16_nearest(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> {
+ %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "nearest", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+// CHECK-LABEL: func.func @test_grid_sample_v16_nearest
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "nearest", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32>
+// CHECK: return [[VAR_0_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
func.func @test_dft(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> {
%cst = "onnx.NoValue"() {value} : () -> none
%0 ="onnx.DFTV17"(%arg0, %arg1) : (tensor, tensor)-> tensor<*xf32>
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 4a625c9405..478ea172b5 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -3819,3 +3819,75 @@ func.func @test_RMSlayer_norm_2inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5
// CHECK: }
}
+// -----
+
+// Test Grid Sample
+
+func.func @test_grid_sample_same_dims(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> {
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_grid_sample_same_dims
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x1152x1344xf32>, [[PARAM_1_:%.+]]: tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> {
+// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32>
+// CHECK: return [[GRID]] : tensor<1x3x1152x1344xf32>
+// CHECK: }
+}
+
+func.func @test_grid_sample_diff_dims(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x2xf32>) -> tensor<*xf32> {
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_grid_sample_diff_dims
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> {
+// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32>
+// CHECK: return [[GRID]] : tensor<1x1x6x6xf32>
+// CHECK: }
+}
+
+func.func @test_grid_sample_6d(%arg0: tensor<1x2x4x4x4x4xf32>, %arg1: tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> {
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_grid_sample_6d
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x4x4x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> {
+// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32>
+// CHECK: return [[GRID]] : tensor<1x2x6x6x4x4xf32>
+// CHECK: }
+}
+
+func.func @test_grid_sample_dim_shape(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> {
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32>
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_grid_sample_dim_shape
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor {
+// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor
+// CHECK: return [[GRID]] : tensor
+// CHECK: }
+ return %0 : tensor<*xf32>
+}
+
+func.func @test_grid_sample_dim_shape2(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> {
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32>
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_grid_sample_dim_shape2
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor {
+// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor
+// CHECK: return [[GRID]] : tensor
+// CHECK: }
+ return %0 : tensor<*xf32>
+}
+
+func.func @test_grid_sample_dim_shape3(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> {
+ %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32>
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_grid_sample_dim_shape3
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor {
+// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor
+// CHECK: return [[GRID]] : tensor
+// CHECK: }
+ return %0 : tensor<*xf32>
+}
\ No newline at end of file
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index a32c931521..e1473ad92d 100755
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -160,7 +160,7 @@
"Gradient": [1],
"Greater": [13],
"GreaterOrEqual": [16],
- "GridSample": [16],
+ "GridSample": [22, 16],
"GroupNormalization": [21, 18],
"HammingWindow": [17],
"HannWindow": [17],
@@ -396,6 +396,7 @@
"Gelu",
"Greater",
"GreaterOrEqual",
+ "GridSample",
"GroupNormalizationV18",
"Hardmax",
"If",