diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 87d2e884c4fd..3b01c79b9eed 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11506,6 +11506,7 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenReshapeAsOp : Torch_Op<"aten.reshape_as", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index a7ce975b0e26..11f021d86ceb 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2261,6 +2261,19 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenReshapeOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) { + auto selfTy = dyn_cast(getSelf().getType()); + auto opTy = dyn_cast(getType()); + if (selfTy && selfTy == opTy && selfTy.hasSizes() && + selfTy.toBuiltinTensor().hasStaticShape()) + return getSelf(); + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenSelectIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 8276af7ccfb6..82c882ea79ee 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -865,7 +865,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::repeat_interleave.Tensor : (Tensor, int?) -> (Tensor)") emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)") emit("aten::tile : (Tensor, int[]) -> (Tensor)") - emit("aten::reshape : (Tensor, int[]) -> (Tensor)") + emit("aten::reshape : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)") emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") emit("aten::resize : (Tensor, int[], int?) -> (Tensor)")