Skip to content

Commit

Permalink
Extend GridSample support (#3060)
Browse files Browse the repository at this point in the history
* feat: add verifier and shape inference for Onnx.GridSample.

Signed-off-by: Rickert, Jonas <[email protected]>

* Support onnx.GridSampleV22

Signed-off-by: Rickert, Jonas <[email protected]>

* Check attributes in GridSample verifier

Signed-off-by: Rickert, Jonas <[email protected]>

---------

Signed-off-by: Rickert, Jonas <[email protected]>
Co-authored-by: Tiago Trevisan Jost <[email protected]>
  • Loading branch information
jorickert and ttjost authored Feb 3, 2025
1 parent 8530104 commit d35d593
Show file tree
Hide file tree
Showing 13 changed files with 417 additions and 4 deletions.
51 changes: 51 additions & 0 deletions docs/Dialects/onnx.md
Original file line number Diff line number Diff line change
Expand Up @@ -3531,6 +3531,57 @@ Effects: `MemoryEffects::Effect{}`

_ONNX GridSample operation_

Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`.
For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2),
the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W),
the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out).
More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr),
the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out).

The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in).
The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values
at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode)
and a padding mode (for `grid` positions falling outside the 2-dimensional image).

For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`.
They are used to interpolate output values of `Y[n, c, h_out, w_out]`.

The GridSample operator is often used in doing grid generator and sampler in the
[Spatial Transformer Networks](https://arxiv.org/abs/1506.02025).
See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).

Traits: `AlwaysSpeculatableImplTrait`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>align_corners</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
<tr><td><code>mode</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>padding_mode</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `X` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values
| `grid` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values

#### Results:

| Result | Description |
| :----: | ----------- |
| `Y` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values

### `onnx.GridSampleV16` (ONNXGridSampleV16Op)

_ONNX GridSample operation_

Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from `grid`.
Currently, only spatial (4-D) inputs are supported. For input `X` with shape (N, C, H, W) and `grid` with shape (N, H_out, W_out, 2),
the output `Y` will have shape (N, C, H_out, W_out).
Expand Down
4 changes: 3 additions & 1 deletion src/Builder/OpBuildTable.inc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ op_dialect_version_map_["GlobalMaxPool"] = {1};
op_dialect_version_map_["Gradient"] = {1};
op_dialect_version_map_["Greater"] = {13};
op_dialect_version_map_["GreaterOrEqual"] = {16};
op_dialect_version_map_["GridSample"] = {16};
op_dialect_version_map_["GridSample"] = {22, 16};
op_dialect_version_map_["GroupNormalization"] = {21, 18};
op_dialect_version_map_["HammingWindow"] = {17};
op_dialect_version_map_["HannWindow"] = {17};
Expand Down Expand Up @@ -356,6 +356,8 @@ import_handler_map_["GreaterOrEqual"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGreaterOrEqualOp>;
import_handler_map_["GridSample"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGridSampleOp>;
import_handler_map_["GridSampleV16"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGridSampleV16Op>;
import_handler_map_["GroupNormalization"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGroupNormalizationOp>;
import_handler_map_["GroupNormalizationV18"] =
Expand Down
1 change: 1 addition & 0 deletions src/Dialect/ONNX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ add_onnx_mlir_library(OMONNXOps
ONNXOps/Tensor/Gather.cpp
ONNXOps/Tensor/GatherElements.cpp
ONNXOps/Tensor/GatherND.cpp
ONNXOps/Tensor/GridSample.cpp
ONNXOps/Tensor/Identity.cpp
ONNXOps/Tensor/NonZero.cpp
ONNXOps/Tensor/OneHot.cpp
Expand Down
53 changes: 52 additions & 1 deletion src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -3062,6 +3062,57 @@ def ONNXGreaterOrEqualOp:ONNX_Op<"GreaterOrEqual",
}

def ONNXGridSampleOp:ONNX_Op<"GridSample",
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
let summary = "ONNX GridSample operation";
let description = [{
Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`.
For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2),
the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W),
the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out).
More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr),
the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out).

The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in).
The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values
at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode)
and a padding mode (for `grid` positions falling outside the 2-dimensional image).

For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`.
They are used to interpolate output values of `Y[n, c, h_out, w_out]`.

The GridSample operator is often used in doing grid generator and sampler in the
[Spatial Transformer Networks](https://arxiv.org/abs/1506.02025).
See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).
}];
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>:$X,
AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$grid,
DefaultValuedAttr<SI64Attr, "0">:$align_corners,
DefaultValuedStrAttr<StrAttr, "linear">:$mode,
DefaultValuedStrAttr<StrAttr, "zeros">:$padding_mode);
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {30};
}
}];
let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef<mlir::Value> oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleOpShapeHelper(op, oper, ieb, scope);
assert(sh && "failed to allocate shape helper");
return sh;
}
}];
let hasVerifier = 1;
}

def ONNXGridSampleV16Op:ONNX_Op<"GridSampleV16",
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
let summary = "ONNX GridSample operation";
let description = [{
Expand Down Expand Up @@ -3099,7 +3150,7 @@ def ONNXGridSampleOp:ONNX_Op<"GridSample",
let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef<mlir::Value> oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleOpShapeHelper(op, oper, ieb, scope);
onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleV16OpShapeHelper(op, oper, ieb, scope);
assert(sh && "failed to allocate shape helper");
return sh;
}
Expand Down
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,7 @@ using ONNXDimOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXDimOp>;
using ONNXDropoutOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXDropoutOp>;
using ONNXDynamicQuantizeLinearOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXDynamicQuantizeLinearOp>;
using ONNXEinsumOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXEinsumOp>;
using ONNXGridSampleOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXGridSampleOp>;
using ONNXEyeLikeOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXEyeLikeOp>;
using ONNXFlattenOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXFlattenOp>;
using ONNXGatherElementsOpShapeHelper = ONNXNonSpecificOpShapeHelper<mlir::ONNXGatherElementsOp>;
Expand Down
128 changes: 128 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------------------ GridSample.cpp - ONNX Operations ------------------===//
//
// Copyright (c) 2024 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file provides definition of ONNX dialect GridSample operation.
//
//===----------------------------------------------------------------------===//

#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"

using namespace mlir;
using namespace mlir::OpTrait::util;
using namespace onnx_mlir;

//===----------------------------------------------------------------------===//
// Support
//===----------------------------------------------------------------------===//

namespace onnx_mlir {

template <>
LogicalResult ONNXGridSampleOpShapeHelper::computeShape() {

// Read data and indices shapes as dim indices.
ONNXGridSampleOpAdaptor operandAdaptor(operands);
DimsExpr inputDims;
DimsExpr gridDims;
createIE->getShapeAsDims(operandAdaptor.getX(), inputDims);
createIE->getShapeAsDims(operandAdaptor.getGrid(), gridDims);

int64_t gridRank = gridDims.size();

// Input's dimensions of rank r+2 should be in the form of (N,C,D1,D2,...,Dr)
// Grid's dimensions should also have rank r+2 and be in the form of
// (N,D1_out,D2_out,...,Dr_out,r).
// The output Y will have shape (N, C, D1_out, D2_out, ..., Dr_out).
DimsExpr outputDims;
outputDims.emplace_back(inputDims[0]);
outputDims.emplace_back(inputDims[1]);
for (int i = 1; i < gridRank - 1; ++i) {
outputDims.emplace_back(gridDims[i]);
}

setOutputDims(outputDims);
return success();
}

} // namespace onnx_mlir

//===----------------------------------------------------------------------===//
// Verify
//===----------------------------------------------------------------------===//

LogicalResult ONNXGridSampleOp::verify() {
ONNXGridSampleOpAdaptor operandAdaptor(*this);
auto op = mlir::cast<ONNXGridSampleOp>(*this);

const auto alignCorners = op.getAlignCorners();
if (alignCorners != 0 && alignCorners != 1) {
return emitOpError("align_corners needs to be 0 or 1");
}
const auto mode = op.getMode();
if (mode != "linear" && mode != "nearest" && mode != "cubic") {
return emitOpError("mode needs to be linear, nearest or cubic");
}
const auto paddingMode = op.getPaddingMode();
if (paddingMode != "zeros" && paddingMode != "border" &&
paddingMode != "reflection") {
return emitOpError("padding_mode needs to be zeros, border or reflection");
}

if (!hasShapeAndRank(getOperation()))
return success();

auto inputShape =
mlir::cast<ShapedType>(operandAdaptor.getX().getType()).getShape();
int64_t inputRank = inputShape.size();
auto gridShape =
mlir::cast<ShapedType>(operandAdaptor.getGrid().getType()).getShape();

// Check whether the ranks of input and grid are valid and are equal.
// Input's dimensions of rank r+2 should be in the form of (N,C,D1,D2,...,Dr)
// Grid's dimensions should also have rank r+2 and be in the form of
// (N,D1_out,D2_out,...,Dr_out,r).
if (inputShape.size() != gridShape.size()) {
return emitOpError() << "Input(=" << inputShape.size()
<< ") and grid(=" << gridShape.size()
<< ") have different dim sizes.";
}

if (inputShape[0] != gridShape[0]) {
return emitOpError() << "Input and grid must have the same batch value.";
}

if (!ShapedType::isDynamic(gridShape.back()) &&
gridShape.back() != inputRank - 2) {
return emitOpError() << "Grid last dim must have been '" << inputRank - 2
<< "' instead of '" << gridShape.back() << "'.";
}

return success();
}

//===----------------------------------------------------------------------===//
// Shape Inference
//===----------------------------------------------------------------------===//

LogicalResult ONNXGridSampleOp::inferShapes(
std::function<void(Region &)> /*doShapeInference*/) {

Type elementType = mlir::cast<ShapedType>(getX().getType()).getElementType();
ONNXGridSampleOpShapeHelper shapeHelper(getOperation(), {});
return shapeHelper.computeShapeAndUpdateType(elementType);
}

//===----------------------------------------------------------------------===//
// Template instantiation
//===----------------------------------------------------------------------===//

namespace onnx_mlir {
template struct ONNXNonSpecificOpShapeHelper<ONNXGridSampleOp>;
} // namespace onnx_mlir
2 changes: 1 addition & 1 deletion src/Dialect/ONNX/ONNXUnsupportedOps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ UNSUPPORTED_OPS(ONNXDeformConvOp)
UNSUPPORTED_OPS(ONNXDictVectorizerOp)
UNSUPPORTED_OPS(ONNXFeatureVectorizerOp)
UNSUPPORTED_OPS(ONNXGradientOp)
UNSUPPORTED_OPS(ONNXGridSampleOp)
UNSUPPORTED_OPS(ONNXHammingWindowOp)
UNSUPPORTED_OPS(ONNXHannWindowOp)
UNSUPPORTED_OPS(ONNXImputerOp)
Expand Down Expand Up @@ -76,6 +75,7 @@ CONVERTED_TO_SUPPORTED_OPS(ONNXClipV11Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXClipV12Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXClipV6Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXDFTV17Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXGridSampleV16Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationOp)
CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationV18Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXPadV18Op)
Expand Down
13 changes: 13 additions & 0 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,18 @@ bool canSequenceAtBeReplaced(Value sequenceAtResult) {
return true;
}

Attribute upgradeGridSampleV16Mode(PatternRewriter &rewriter, Attribute mode) {
const auto stringMode = mlir::cast<StringAttr>(mode);
if (stringMode.strref() == "bilinear") {
return rewriter.getStringAttr("linear");
}
if (stringMode.strref() == "bicubic") {
return rewriter.getStringAttr("cubic");
}
assert(stringMode.strref() == "nearest");
return mode;
}

Value replaceSequenceAt(
PatternRewriter &rewriter, Location loc, Value sequenceAtResult) {
ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp<ONNXSequenceAtOp>();
Expand Down Expand Up @@ -1318,6 +1330,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
target.addIllegalOp<ONNXClipV6Op>();
target.addIllegalOp<ONNXConstantOfShapeOp>();
target.addIllegalOp<ONNXDFTV17Op>();
target.addIllegalOp<ONNXGridSampleV16Op>();
target.addIllegalOp<ONNXGroupNormalizationOp>();
target.addIllegalOp<ONNXGroupNormalizationV18Op>();
target.addIllegalOp<ONNXInstanceNormalizationOp>();
Expand Down
9 changes: 9 additions & 0 deletions src/Dialect/ONNX/Transforms/Decompose.td
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def ReshapeElementsAttrToRank0 : NativeCodeCall<

def ReplaceSequenceAt : NativeCodeCall<
"onnx_mlir::replaceSequenceAt($_builder, $_loc, $0)">;

def UpgradeGridSampleV16Mode : NativeCodeCall<
"onnx_mlir::upgradeGridSampleV16Mode($_builder, $0)">;

def CanSequenceAtBeReplaced :
Constraint<CPred<"::onnx_mlir::canSequenceAtBeReplaced($_self)">, "check whether the SequenceAt can be replaced with split">;
Expand Down Expand Up @@ -365,6 +368,12 @@ def ClipV12Pattern : Pat<
(ONNXClipOp $x, $min, $max)
>;

// Rewrite GridSample 16 to GridSample 22
def GridSampleV16Pattern : Pat<
(ONNXGridSampleV16Op $x, $grid, $align_corners, $mode, $padding_mode),
(ONNXGridSampleOp $x, $grid, $align_corners, (UpgradeGridSampleV16Mode $mode), $padding_mode)
>;

def DFTV17Pattern : Pat<
(ONNXDFTV17Op $x, $dft_length, $axis, $inverse, $onesided),
(ONNXDFTOp $x, $dft_length, (ONNXConstantOpFromDenseAttr(createScalarDenseAttrRank0 $axis)), $inverse, $onesided)
Expand Down
Loading

0 comments on commit d35d593

Please sign in to comment.