[AutoBump] Merge with 1a9c0a3 (Jun 07)
mgehre-amd committed Sep 9, 2024
2 parents 813abc3 + 1a9c0a3 commit 7c5a142
Showing 30 changed files with 1,892 additions and 168 deletions.
56 changes: 54 additions & 2 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6819,6 +6819,31 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
}];
}

def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$indices,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
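
As a rough illustration (not taken from the patch), a use of this generated op in the default torch-op assembly form might look like the sketch below; the tensor shapes and SSA names are hypothetical.

```
%int4 = torch.constant.int 4
// output_size for the unpooled result (hypothetical 4x4 spatial size).
%output_size = torch.prim.ListConstruct %int4, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%unpooled = torch.aten.max_unpool2d %self, %indices, %output_size : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
```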

def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -6907,6 +6932,33 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
}];
}

def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$indices,
AnyTorchListOfTorchIntType:$output_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxUnpool3dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenMaxUnpool3dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
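
Similarly, a hypothetical use of the 3-d variant, which additionally takes `stride` and `padding` lists, might look like this sketch (shapes illustrative only):

```
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%int4 = torch.constant.int 4
%output_size = torch.prim.ListConstruct %int4, %int4, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%stride = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%unpooled = torch.aten.max_unpool3d %self, %indices, %output_size, %stride, %padding : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4,4],f32>
```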

def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -16250,11 +16302,11 @@ def Torch_PrimsVarOp : Torch_Op<"prims.var", [
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prims::var : (Tensor, int[]?, float, int?) -> (Tensor)`";
let summary = "Generated op for `prims::var : (Tensor, int[]?, float?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$inp,
AnyTorchOptionalListOfTorchIntType:$dims,
Torch_FloatType:$correction,
AnyTorchOptionalFloatType:$correction,
AnyTorchOptionalIntType:$output_dtype
);
let results = (outs
64 changes: 64 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -11,6 +11,7 @@
#define TORCH_OPS

include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
@@ -1337,4 +1338,67 @@ def Torch_DtypeCalculateYieldDtypesOp : Torch_Op<"dtype.calculate.yield.dtypes",
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Symbolic shape modeling ops for TorchDynamo frontend.
//===----------------------------------------------------------------------===//

def Torch_SymbolicIntOp : Torch_Op<"symbolic_int", [Pure]> {
let summary = "Symbolic int representing a dynamic dimension";
let description = [{
The `torch.symbolic_int` operation captures a dynamic dimension of the
global function arguments as exported by TorchDynamo (torch.export).
It associates a shape symbol (e.g. "s0", "s1") with a
global SSA value (e.g. `%0`, `%1`) that is then referenced
to bind shapes on op results.

Additionally, the operation carries `min_val` and `max_val` attributes
denoting the range constraints for the dynamic dimension. This may be
useful for modeling runtime shape guards, or for compile-time optimizations
based on the shape bounds (min, opt, max) on results of ops / regions.

Example:
```
%0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 2, max_val = 20} : !torch.int
```
}];
let arguments = (ins
StrAttr:$symbol_name,
I64Attr:$min_val,
I64Attr:$max_val
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = [{
$symbol_name ` ` `{` `min_val` `=` $min_val `,` `max_val` `=` $max_val `}` attr-dict `:` type($result)
}];
}

def Torch_BindSymbolicShapeOp : Torch_Op<"bind_symbolic_shape", []> {
let summary = "Binds shape expressions to tensors using an affine map indexed by shape symbols";
let description = [{
The `torch.bind_symbolic_shape` operation binds shape expressions
that compute the dynamic dimensions of a tensor. It takes a
variadic list of SSA symbol values that map 1:1 to the symbols declared
in the affine map. The affine map contains one affine shape
expression per dimension, whose terminals are the declared
symbols.

Example:
```
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
torch.bind_symbolic_shape %out0, [%0, %1, %2], affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>
```
}];
let arguments = (ins
Torch_ValueTensorType:$operand,
Variadic<Torch_IntType>:$shape_symbols,
Builtin_AffineMapAttr:$shape_expressions
);
let results = (outs);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
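
Combining the two examples above, a function imported from torch.export is expected to carry one `torch.symbolic_int` per dynamic size and a `torch.bind_symbolic_shape` per dynamically shaped value; a minimal sketch (the function body is illustrative, not taken from the patch):

```
func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
  %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
  %1 = torch.symbolic_int "s1" {min_val = 2, max_val = 20} : !torch.int
  torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
  return %arg0 : !torch.vtensor<[?,?,3],f32>
}
```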

#endif // TORCH_OPS
94 changes: 80 additions & 14 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -951,7 +951,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: auto_pad != NOTSET");
}

Torch::ValueTensorType resultType;
Value input, weight;
int64_t group;
@@ -1034,23 +1033,94 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(

SmallVector<Value> cstPadding, cstStrides, cstDilations,
cstOutputPadding;
Value paddedInput = input;
Value paddingList;
if (padding.size() != 2 * (rank - 2)) {
for (int64_t i : padding) {
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
cstPadding);
} else {
// ONNX offers pads in the format listing all starting dims, then all
// ending dims, e.g. {t, l, b, r} for conv2d. Torch by default accepts
// only starting dims, e.g. {t, l}. However, we can support padding at
// the beginning and end of each dimension by first performing
// torch.nn.functional.pad on the input. But this requires the pad
// values to be rearranged since torch pad() takes pads in the order
// rightmost dim start and end, then next to last, and so on, e.g. {l,
// r, t, b}.
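// For example (hypothetical values): ONNX pads {1, 2, 3, 4} for a 2-d
// conv mean {t, l, b, r}; listed in torch pad() order they become
// {l, r, t, b} = {2, 4, 1, 3}.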
bool matchedPads = true;
for (unsigned i = 0; i < padding.size() / 2; i++) {
if (padding[i] != padding[i + (padding.size() / 2)]) {
// TODO: Add support for different padding values for the
// beginning and ending along each spatial axis
return rewriter.notifyMatchFailure(
binder.op,
"unsupported conversion: padding values for the beginning "
"and ending along each spatial axis must be equal");
matchedPads = false;
break;
}
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
}
if (matchedPads) {
for (unsigned i = 0; i < padding.size() / 2; i++) {
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
}
paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
cstPadding);
} else {
SmallVector<Value> padsRearrange;
SmallVector<Value> inputPaddingList;
for (uint32_t i = 0; i < padding.size() / 2; i++) {
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
padding[(padding.size() / 2) + i])));
inputPaddingList.emplace_back(
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
}
// The conv op itself will have no padding since the actual padding
// is performed using the torch.pad preceding it.
paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
inputPaddingList);
Value padsSizeList =
rewriter
.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
rewriter.getType<Torch::IntType>()),
padsRearrange)
.getResult();
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
binder.getLoc(), rewriter.getStringAttr("constant"));
Value constantValue;
auto inputTensorType =
cast<Torch::ValueTensorType>(input.getType());
if (isa<IntegerType>(inputTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
if (isa<FloatType>(inputTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
// Pad output shape must be computed explicitly from the pad values
SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
for (uint32_t i = 0; i < padding.size() / 2; i++) {
newInputShape[2 + i] +=
padding[i] + padding[(padding.size() / 2) + i];
}
auto padTy = rewriter.getType<Torch::ValueTensorType>(
newInputShape, inputTensorType.getDtype());
paddedInput = rewriter.create<Torch::AtenPadOp>(
binder.getLoc(), padTy, input, padsSizeList, modeVal,
constantValue);
}
}
for (int64_t i : dilations) {
@@ -1065,10 +1135,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
cstOutputPadding = {cstZero, cstZero};

Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstPadding);
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
@@ -1095,7 +1161,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.getLoc(), rewriter.getI64IntegerAttr(group));

rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
binder.op, resultType, input, weight, bias, stridesList,
binder.op, resultType, paddedInput, weight, bias, stridesList,
paddingList, dilationsList, transposed, outputPaddingList,
cstGroup);
return success();
78 changes: 78 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1918,4 +1918,82 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(

return success();
});
patterns.onOp(
"MaxUnpool", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
// TODO: Add support for `output_shape` arg.
if (binder.op->getNumOperands() == 3)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: output_shape arg is not supported");

Torch::ValueTensorType resultType;
Value data, indices;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorOperandAtIndex(indices, 1) ||
binder.tensorResultType(resultType))
return rewriter.notifyMatchFailure(
binder.op, "data/indices/resultType bind failure");
std::optional<unsigned> maybeRank = Torch::getTensorRank(data);
if (!maybeRank)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: unranked tensor");
int64_t rank = *maybeRank;
int64_t spatial = rank - 2;

if (rank <= 3 || rank > 5)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: MaxUnpool support "
"only present for rank 4/5 input");

if (!(resultType.hasSizes() && resultType.areAllSizesKnown()))
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: expected result to have all shapes "
"statically known");

SmallVector<int64_t> resultShape(resultType.getSizes());
Value resultShapeList =
createConstantIntList(binder, rewriter, resultShape);
if (rank == 4) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
binder.op, resultType, data, indices, resultShapeList);
return success();
}

SmallVector<int64_t> padding, strides;
if (binder.s64IntegerArrayAttr(padding, "pads", {}))
return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
if (!padding.empty() &&
padding.size() != static_cast<size_t>(2 * spatial))
return rewriter.notifyMatchFailure(
binder.op, "padding list must contain (begin,end) pair for each "
"spatial axis");
if (binder.s64IntegerArrayAttr(strides, "strides", {}))
return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
if (!strides.empty() && strides.size() != static_cast<size_t>(spatial))
return rewriter.notifyMatchFailure(
binder.op, "strides list size does not match the number of axes");

if (padding.empty())
padding.resize(spatial, 0);
if (strides.empty())
strides.resize(spatial, 1);

// If the padding is symmetric we can push the padding
// operation to the torch operator.
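// For example (hypothetical values), pads = {1, 2, 0, 1, 2, 0} has equal
// begin/end values for every spatial axis and collapses to {1, 2, 0}.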
if (padding.size() == static_cast<size_t>(2 * spatial)) {
bool equal = true;
for (int i = 0; i < spatial; ++i) {
equal = equal && (padding[i] == padding[i + spatial]);
}
if (equal)
padding.resize(spatial);
}

Value paddingList = createConstantIntList(binder, rewriter, padding);
Value stridesList = createConstantIntList(binder, rewriter, strides);

rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool3dOp>(
binder.op, resultType, data, indices, resultShapeList, stridesList,
paddingList);
return success();
});
}
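
For reference, a rough before/after sketch of what this MaxUnpool pattern produces for a rank-4 input; the operand shapes, attribute spelling, and SSA names are illustrative, not taken from the patch:

```
// ONNX input (illustrative):
%0 = torch.operator "onnx.MaxUnpool"(%data, %indices) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32>

// Approximate rewritten form for the rank-4 case: the full static result
// shape is materialized as a list and passed to aten.max_unpool2d.
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%shape = torch.prim.ListConstruct %int1, %int1, %int4, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.max_unpool2d %data, %indices, %shape : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
```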