Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoBump] Merge with fixes of 11cd7cd9 (23) #254

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6720,6 +6720,7 @@ def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices",
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
Expand Down Expand Up @@ -16090,6 +16091,7 @@ def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_PrimsVarOp : Torch_Op<"prims.var", [
Expand Down
39 changes: 39 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4799,6 +4799,45 @@ LogicalResult AtenPermuteOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// PrimsConvertElementTypeOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
auto inputType = cast<BaseTensorType>(getA().getType());
auto outputType = cast<BaseTensorType>(getResult().getType());
if (inputType != outputType)
return nullptr;
if (!inputType.hasDtype() || !outputType.hasDtype())
return nullptr;
if (inputType.getDtype() != outputType.getDtype())
return nullptr;
return getA();
}

//===----------------------------------------------------------------------===//
// AtenMaxPool2dWithIndicesOp
//===----------------------------------------------------------------------===//

void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
if (!op.getResult1().use_empty()) {
return rewriter.notifyMatchFailure(
op, "result1 of MaxPool2dWithIndices should be unused");
}

Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
op->getLoc(), op.getResult0().getType(), op.getSelf(),
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
op.getCeilMode());

op.getResult0().replaceAllUsesWith(result);
rewriter.eraseOp(op);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenLinalgCrossOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,8 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit(
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
has_canonicalizer=True,
)
emit(
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
Expand Down Expand Up @@ -1120,7 +1121,7 @@ def emit_with_mutating_variants(key, **kwargs):
# `prims::` namespace.
# ==========================================================================

emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
emit("prims::convert_element_type : (Tensor, int) -> (Tensor)", has_folder=True)
emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)")
emit("prims::sqrt : (Tensor) -> (Tensor)")
emit("prims::collapse : (Tensor, int, int) -> (Tensor)")
Expand Down
41 changes: 41 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2974,3 +2974,44 @@ func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
%result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
return %result : !torch.vtensor<[4], f32>
}

// -----

// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold(
// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
// CHECK: return %[[ARG]] : !torch.vtensor<[64],f32>
func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
%int6 = torch.constant.int 6
%0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
return %0 : !torch.vtensor<[64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.prims.convert_element_type$no_fold(
// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
// CHECK: %[[RET:.*]] = torch.prims.convert_element_type %[[ARG]], %{{.*}} : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
// CHECK: return %[[RET]] : !torch.vtensor<[64],si32>
func.func @torch.prims.convert_element_type$no_fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
%int6 = torch.constant.int 6
%0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
return %0 : !torch.vtensor<[64],si32>
}

// -----

// CHECK-LABEL: @torch.aten.max_pool2d_with_indices$canonicalize(
// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
// CHECK: %[[RET:.*]] = torch.aten.max_pool2d %[[ARG]]
// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56],f32>
func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
return %result0 : !torch.vtensor<[10,64,56,56],f32>
}
Loading