Skip to content

Commit

Permalink
[AutoBump] Merge with adafd51 (May 10)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Aug 27, 2024
2 parents 37f6ca0 + adafd51 commit f5ff46e
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 100 deletions.
65 changes: 38 additions & 27 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1463,56 +1463,67 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
});
patterns.onOp(
"CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
Value operand;
Value axisTensor;
Value operand, axisTensor;
int64_t exclusive, reverse;
if (binder.tensorOperands(operand, axisTensor) ||
binder.s64IntegerAttr(exclusive, "exclusive", 0) ||
binder.s64IntegerAttr(reverse, "reverse", 0) ||
binder.tensorResultType(resultType))
return failure();

int64_t exclusive;
int64_t reverse;
// if bind succeeds and either is set, fail because not implemented
if (!binder.s64IntegerAttr(exclusive, "exclusive", 0))
if (exclusive != 0)
return rewriter.notifyMatchFailure(
binder.op, "unsupported onnx.CumSum conversion: exclusive");
if (!binder.s64IntegerAttr(reverse, "reverse", 0))
if (reverse != 0)
return rewriter.notifyMatchFailure(
binder.op, "unsupported onnx.CumSum conversion: reverse");
Torch::BaseTensorType resultTensorType =
cast<Torch::BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype");
}

// deal with neg axis: if (axis < 0) axis += rank
int64_t rank =
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
Value rankVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank));
Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));

Value axisScalar = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), axisTensor);
Value isNegative = rewriter.create<Torch::AtenLtIntOp>(
binder.getLoc(), axisScalar, zero);
binder.getLoc(), axisScalar, cstZero);
isNegative =
rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, rankVal);
Value dim = rewriter.create<Torch::AtenAddIntOp>(
Value axis = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), axisScalar, finalOffset);
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

Torch::BaseTensorType resultTensorType =
cast<Torch::BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype");
Value res;
if (reverse) {
Value dims = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{axis});
Value flip = rewriter.create<Torch::AtenFlipOp>(
binder.getLoc(), resultType, operand, dims);
Value cumsum = rewriter.create<Torch::AtenCumsumOp>(
binder.getLoc(), resultType, flip, axis, none);
res = rewriter.create<Torch::AtenFlipOp>(binder.getLoc(), resultType,
cumsum, dims);
} else {
res = rewriter.create<Torch::AtenCumsumOp>(
binder.getLoc(), resultType, operand, axis, none);
}
// resultTensorType.print(llvm::outs());
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(binder.op, resultType,
operand, dim, none);

if (exclusive)
res = rewriter.create<Torch::AtenSubTensorOp>(
binder.getLoc(), resultType, res, operand, cstOne);
rewriter.replaceOp(binder.op, res);
return success();
});
patterns.onOp(
Expand Down
19 changes: 14 additions & 5 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
if (gridShape[3] != 2)
return rewriter.notifyMatchFailure(binder.op,
"gridShape[3] expected to be 2");
std::string mode;
if (binder.customOpNameStringAttr(mode, "mode", "linear"))
std::string iModeString;
int64_t iModeInt;
if (binder.customOpNameStringAttr(iModeString, "mode", "linear"))
return rewriter.notifyMatchFailure(binder.op, "mode bind failure");
if (mode != "linear" && mode != "bilinear")

if (iModeString == "linear" || iModeString == "bilinear") {
iModeInt = 0;
} else if (iModeString == "nearest") {
iModeInt = 1;
} else {
return rewriter.notifyMatchFailure(
binder.op, "currently only mode : linear supported");
binder.op, "currently only mode : linear and nearest supported");
}

std::string padding;
if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
return rewriter.notifyMatchFailure(binder.op,
Expand All @@ -143,7 +151,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(

Value interpolationMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt));

Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Expand Down
39 changes: 33 additions & 6 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2524,14 +2524,38 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
Value result = b.create<tensor::ExtractOp>(loc, input, index);
return result;
};
auto lambdaInter = [&](OpBuilder &b, Location loc, Value x, Value y,
Value d) -> Value {

auto lambdaLinear = [&](OpBuilder &b, Location loc, Value x, Value y,
Value d) -> Value {
Value dm = b.create<arith::SubFOp>(loc, oneFloat, d);
Value ra = b.create<arith::MulFOp>(loc, x, dm);
Value rb = b.create<arith::MulFOp>(loc, y, d);
Value res = b.create<arith::AddFOp>(loc, ra, rb);
return res;
};

auto lambdaNearest = [&](OpBuilder &b, Location loc, Value x, Value y,
Value d) -> Value {
Value halfConst = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 0.5));
Value checkClosest =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, d, halfConst);
Value res = b.create<arith::SelectOp>(loc, checkClosest, x, y);
return res;
};

auto lambdaInterpolate = [&](OpBuilder &b, Location loc, Value iMode,
Value x, Value y, Value d) -> Value {
Value linear = lambdaLinear(b, loc, x, y, d);
Value nearest = lambdaNearest(b, loc, x, y, d);
Value zeroInt =
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
Value checkMode = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
iMode, zeroInt);
Value res = b.create<arith::SelectOp>(loc, checkMode, linear, nearest);
return res;
};

auto resultType = getTypeConverter()
->convertType(op.getResult().getType())
.cast<RankedTensorType>();
Expand All @@ -2545,6 +2569,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
if (resultType.isDynamicDim(3))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2));
Value alignCorners = adaptor.getAlignCorners();
Value interMode = adaptor.getInterpolationMode();
Value resultFinal =
rewriter.create<tensor::EmptyOp>(loc, resultType, resultSize);
auto sGrid = rewriter.create<linalg::GenericOp>(
Expand Down Expand Up @@ -2633,10 +2658,12 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
Value lw1a = b.create<arith::SIToFPOp>(loc, floatType, lower1);
Value d1 = b.create<arith::SubFOp>(loc, result0, lw0a);
Value d0 = b.create<arith::SubFOp>(loc, result1, lw1a);
Value resultScaled0 = lambdaInter(b, loc, result00b, result01b, d0);
Value resultScaled1 = lambdaInter(b, loc, result10b, result11b, d0);
Value resultScaled =
lambdaInter(b, loc, resultScaled0, resultScaled1, d1);
Value resultScaled0 =
lambdaInterpolate(b, loc, interMode, result00b, result01b, d0);
Value resultScaled1 =
lambdaInterpolate(b, loc, interMode, result10b, result11b, d0);
Value resultScaled = lambdaInterpolate(
b, loc, interMode, resultScaled0, resultScaled1, d1);
b.create<linalg::YieldOp>(loc, resultScaled);
});
rewriter.replaceOp(op, sGrid.getResults());
Expand Down
38 changes: 38 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,41 @@ def GridSamplerBasic3_basic(module, tu: TestUtils):
[[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]]
).type(torch.FloatTensor)
module.forward(inp, grd)


class GridSamplerBasic4(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[None, ([1, 1, 4, 4], torch.float32, True), ([1, 1, 3, 2], torch.float32, True)]
)
def forward(self, x, g):
interpolation_mode = (1,)
padding_mode = (0,)
align_corners = (False,)
tRes = torch.ops.aten.grid_sampler(
x, g, interpolation_mode[0], padding_mode[0], align_corners[0]
)
return tRes


@register_test_case(module_factory=lambda: GridSamplerBasic4())
def GridSamplerBasic4_basic(module, tu: TestUtils):
inp = torch.tensor(
[
[
[
[0.4963, 0.7682, 0.0885, 0.1320],
[0.3074, 0.6341, 0.4901, 0.8964],
[0.4556, 0.6323, 0.3489, 0.4017],
[0.0223, 0.1689, 0.2939, 0.5185],
]
]
]
).type(torch.FloatTensor)
grd = torch.tensor(
[[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]]
).type(torch.FloatTensor)
module.forward(inp, grd)
79 changes: 79 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,85 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a

// -----

// CHECK-LABEL: @test_cumsum_1d_exclusive
func.func @test_cumsum_1d_exclusive(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[RANK:.*]] = torch.constant.int 1
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
// CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %arg0, %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64>
// CHECK: torch.aten.sub.Tensor %[[CUMSUM]], %arg0, %[[C1]] : !torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>, !torch.int -> !torch.vtensor<[5],f64>
%0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.exclusive = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64>
return %0 : !torch.vtensor<[5],f64>
}

// -----

// CHECK-LABEL: @test_cumsum_1d_reverse
func.func @test_cumsum_1d_reverse(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[RANK:.*]] = torch.constant.int 1
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
// CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
// CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %[[FLIP]], %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64>
// CHECK: torch.aten.flip %[[CUMSUM]], %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
%0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.reverse = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64>
return %0 : !torch.vtensor<[5],f64>
}

// -----

// CHECK-LABEL: @test_cumsum_1d_reverse_exclusive
func.func @test_cumsum_1d_reverse_exclusive(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[RANK:.*]] = torch.constant.int 1
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
// CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
// CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %[[FLIP]], %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64>
// CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[CUMSUM]], %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
// CHECK: torch.aten.sub.Tensor %[[FLIP_0]], %arg0, %[[C1]] : !torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>, !torch.int -> !torch.vtensor<[5],f64>
%0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.exclusive = 1 : si64, torch.onnx.reverse = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64>
return %0 : !torch.vtensor<[5],f64>
}

// -----

// CHECK-LABEL: @test_cumsum_2d
func.func @test_cumsum_2d(%arg0: !torch.vtensor<[2,3],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[RANK:.*]] = torch.constant.int 2
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
// CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// torch.aten.cumsum %arg0, %[[ADD]], %[[NONE]] : !torch.vtensor<[2,3],f64>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f64>
%0 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64>
return %0 : !torch.vtensor<[2,3],f64>
}

// -----

// CHECK-LABEL: func.func @test_exp
func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64} {
// CHECK: torch.aten.exp %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
Expand Down
Loading

0 comments on commit f5ff46e

Please sign in to comment.