
Commit

Merge branch 'feature/backport_ea1_ops' into tina.error-out-onnx-op-version-type-mismatch
TinaAMD authored Jul 10, 2024
2 parents 9e0bc40 + 2f137b6 commit d212544
Showing 17 changed files with 129 additions and 136 deletions.
15 changes: 15 additions & 0 deletions .github/dependabot.yml
@@ -0,0 +1,15 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates

version: 2
updates:
  - package-ecosystem: "gitsubmodule"
    directory: "/"
    allow:
      - dependency-name: "externals/llvm-project"
    schedule:
      interval: "daily"
      time: "06:00"
      timezone: "Europe/Berlin"
28 changes: 28 additions & 0 deletions .github/workflows/approve_dependabot.yml
@@ -0,0 +1,28 @@
name: Dependabot auto-approve & auto-merge
on: pull_request

permissions:
  pull-requests: write
  # Needed to enable auto-merge
  contents: write

jobs:
  dependabot:
    runs-on: ubuntu-latest
    if: github.actor == 'dependabot[bot]'
    steps:
      - name: Dependabot metadata
        id: metadata
        uses: dependabot/fetch-metadata@v2
        with:
          github-token: "${{ secrets.GITHUB_TOKEN }}"
      - name: Approve a PR
        run: gh pr review --approve "$PR_URL"
        env:
          PR_URL: ${{github.event.pull_request.html_url}}
          GH_TOKEN: ${{secrets.GITHUB_TOKEN}}
      - name: Enable auto-merge for Dependabot PRs
        run: gh pr merge --auto --merge "$PR_URL"
        env:
          PR_URL: ${{github.event.pull_request.html_url}}
          GH_TOKEN: ${{secrets.GITHUB_TOKEN}}
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 75 files
+4 −0 mlir/include/mlir-c/IR.h
+2 −2 mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
+3 −1 mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h
+3 −1 mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h
+3 −1 mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
+8 −0 mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+3 −0 mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+4 −0 mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+4 −2 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+9 −0 mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+17 −0 mlir/include/mlir/Dialect/EmitC/Transforms/TypeConversions.h
+28 −14 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+3 −4 mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+3 −0 mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+12 −2 mlir/include/mlir/IR/PatternMatch.h
+4 −1 mlir/include/mlir/Tools/PDLL/Parser/Parser.h
+14 −7 mlir/lib/Bindings/Python/IRCore.cpp
+3 −2 mlir/lib/Bindings/Python/IRModule.h
+3 −0 mlir/lib/CAPI/IR/IR.cpp
+299 −72 mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+6 −3 mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
+1 −0 mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
+37 −7 mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp
+9 −1 mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp
+5 −3 mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+62 −36 mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+18 −7 mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+5 −1 mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
+1 −1 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+1 −0 mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+11 −4 mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+28 −9 mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+1 −0 mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
+39 −0 mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+12 −6 mlir/lib/Dialect/PDL/IR/Builtins.cpp
+9 −6 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+266 −56 mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+7 −2 mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+4 −0 mlir/lib/Target/Cpp/TranslateToCpp.cpp
+21 −0 mlir/lib/Tools/PDLL/Parser/Lexer.cpp
+14 −0 mlir/lib/Tools/PDLL/Parser/Lexer.h
+67 −3 mlir/lib/Tools/PDLL/Parser/Parser.cpp
+4 −1 mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+3 −0 mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+15 −3 mlir/test/CAPI/ir.c
+46 −10 mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+299 −16 mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+30 −0 mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir
+32 −4 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+16 −4 mlir/test/Conversion/SCFToEmitC/for.mlir
+33 −0 mlir/test/Conversion/SCFToEmitC/nest-for-if.mlir
+9 −5 mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+11 −0 mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+15 −0 mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-bodiless-functions-results.mlir
+4 −4 mlir/test/Dialect/EmitC/invalid_ops.mlir
+2 −2 mlir/test/Dialect/EmitC/invalid_types.mlir
+4 −4 mlir/test/Dialect/EmitC/ops.mlir
+10 −0 mlir/test/Dialect/EmitC/types.mlir
+46 −0 mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir
+75 −0 mlir/test/Dialect/Tosa/constant-cos.mlir
+131 −0 mlir/test/Dialect/Tosa/constant-floor.mlir
+194 −26 mlir/test/Dialect/Tosa/constant-op-fold.mlir
+75 −0 mlir/test/Dialect/Tosa/constant-sin.mlir
+145 −0 mlir/test/Dialect/Tosa/constant-tile.mlir
+20 −20 mlir/test/Target/Cpp/for.mlir
+10 −0 mlir/test/Target/Cpp/types.mlir
+11 −0 mlir/test/mlir-pdll/Parser/include-file-and-file-which-includes-first-file-with-once-00.pdll
+11 −0 mlir/test/mlir-pdll/Parser/include-file-and-file-which-includes-first-file-with-once-01.pdll
+9 −0 mlir/test/mlir-pdll/Parser/include-file-twice-with-once.pdll
+5 −0 mlir/test/mlir-pdll/Parser/include-file-twice-without-once.pdll
+5 −0 mlir/test/mlir-pdll/Parser/include/include-file-with-include.pdll
+5 −0 mlir/test/mlir-pdll/Parser/include/included-with-once.pdll
+8 −1 mlir/test/python/ir/operation.py
+11 −3 mlir/tools/mlir-pdll/mlir-pdll.cpp
+7 −2 mlir/unittests/Dialect/PDL/BuiltinTest.cpp
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToLinalg/Linear.cpp
@@ -682,7 +682,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

    Type intType = IntegerType::get(context, 64);
    auto castIndexToInt = [&](Value v) {
-     return rewriter.create<arith::IndexCastOp>(loc, intType, v);
+     return rewriter.createOrFold<arith::IndexCastOp>(loc, intType, v);
    };

    SmallVector<Value> paddingIntValues;
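Most of the source changes in this commit replace OpBuilder::create<...> with OpBuilder::createOrFold<...>. Unlike create, createOrFold runs the op's folder immediately after building it, so a cast or arithmetic op whose operands are constants is replaced by a constant value instead of materializing a new op; this is also why the FileCheck expectations later in this diff change from torch_c.to_i64 to arith.constant. A minimal sketch of the pattern, illustrative only and not part of the patch (the helper name castIndexToI64 and the includes are assumptions):

// Illustrative sketch only, not from this commit. Assumes the MLIR OpBuilder
// and arith dialect headers; the helper name castIndexToI64 is hypothetical.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static Value castIndexToI64(OpBuilder &b, Location loc, Value v) {
  // create<arith::IndexCastOp> would always insert an index_cast op here.
  // createOrFold<arith::IndexCastOp> first asks the op's folder: if v is
  // produced by an arith.constant, the result is the folded constant and no
  // cast op is added to the IR.
  return b.createOrFold<arith::IndexCastOp>(loc, b.getI64Type(), v);
}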
32 changes: 13 additions & 19 deletions lib/Conversion/TorchToLinalg/Utils.cpp
@@ -86,16 +86,10 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
       pad < paddingIncludingUnchanged.end(); pad++)
    *pad = castIntToIndex(b, loc, *pad);

-  Type elementType = input.getType().cast<RankedTensorType>().getElementType();
-  // TODO: audit possibility of sparsity on this tensor
-  Type inputType =
-      RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
-                                SmallVector<int64_t>(inRank, kUnknownSize))),
-                            elementType);

  SmallVector<OpFoldResult> paddingValues =
      getAsOpFoldResult(paddingIncludingUnchanged);
-  return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,
+  return b.create<tensor::PadOp>(loc, Type{}, input, /*low=*/paddingValues,
                                 /*high=*/paddingValues, pad);
}

@@ -107,25 +101,25 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
  Value c1 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
  Value c2 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));

-  Value doublePadding = b.create<arith::MulIOp>(loc, paddingInt, c2);
+  Value doublePadding = b.createOrFold<arith::MulIOp>(loc, paddingInt, c2);
  // in + 2 * padding
-  Value inAddDoublePadding =
-      b.create<arith::AddIOp>(loc, castIndexToInt64(b, loc, in), doublePadding);
+  Value inAddDoublePadding = b.createOrFold<arith::AddIOp>(
+      loc, castIndexToInt64(b, loc, in), doublePadding);

  // dilation * (kernelSize - 1)
-  Value kernelSizeSub1 = b.create<arith::SubIOp>(loc, kernelSizeInt, c1);
+  Value kernelSizeSub1 = b.createOrFold<arith::SubIOp>(loc, kernelSizeInt, c1);
  Value dilationTimesKernelSize =
-      b.create<arith::MulIOp>(loc, dilationInt, kernelSizeSub1);
+      b.createOrFold<arith::MulIOp>(loc, dilationInt, kernelSizeSub1);

-  Value temp =
-      b.create<arith::SubIOp>(loc, inAddDoublePadding, dilationTimesKernelSize);
-  Value dividend = b.create<arith::SubIOp>(loc, temp, c1);
+  Value temp = b.createOrFold<arith::SubIOp>(loc, inAddDoublePadding,
+                                             dilationTimesKernelSize);
+  Value dividend = b.createOrFold<arith::SubIOp>(loc, temp, c1);
  Value division;
  if (ceilMode)
-    division = b.create<arith::CeilDivSIOp>(loc, dividend, strideInt);
+    division = b.createOrFold<arith::CeilDivSIOp>(loc, dividend, strideInt);
  else
-    division = b.create<arith::FloorDivSIOp>(loc, dividend, strideInt);
-  Value out = b.create<arith::AddIOp>(loc, division, c1);
+    division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt);
+  Value out = b.createOrFold<arith::AddIOp>(loc, division, c1);
  return castIntToIndex(b, loc, out);
}

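For reference, the chain of arith ops built above computes the standard convolution output-size formula, out = (in + 2 * padding - dilation * (kernelSize - 1) - 1) / stride + 1, using ceiling division when ceilMode is set and floor division otherwise. A plain scalar sketch of the same arithmetic, illustrative only and assuming non-negative operands so that C++ integer division matches floor division:

// Scalar version of the computation getOutputDimForConvOps emits as arith ops;
// illustrative only, not part of the patch.
#include <cstdint>

int64_t outputDimForConv(int64_t in, int64_t padding, int64_t dilation,
                         int64_t kernelSize, int64_t stride, bool ceilMode) {
  int64_t dividend = in + 2 * padding - dilation * (kernelSize - 1) - 1;
  // For non-negative dividend and positive stride, truncating division equals
  // floor division; adding (stride - 1) first turns it into ceiling division.
  int64_t division =
      ceilMode ? (dividend + stride - 1) / stride : dividend / stride;
  return division + 1;
}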
8 changes: 4 additions & 4 deletions lib/Conversion/Utils/Utils.cpp
@@ -139,13 +139,13 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
}

Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
-  assert(v.getType().isa<IntegerType>() && "must be called with integer type");
-  return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
+  assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
+  return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);
}

Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
-  assert(idx.getType().isa<IndexType>() && "must be called with integer type");
-  return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
+  assert(isa<IndexType>(idx.getType()) && "must be called with integer type");
+  return b.createOrFold<arith::IndexCastOp>(loc, b.getI64Type(), idx);
}

SmallVector<Value>
@@ -94,7 +94,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
        if (!inputs[0].getType().isa<Torch::IntType>())
          return std::nullopt;
        assert(inputs.size() == 1);
-       return builder.create<ToI64Op>(loc, inputs[0]).getResult();
+       return builder.createOrFold<ToI64Op>(loc, inputs[0]);
      });
auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type,
ValueRange inputs, Location loc) -> Value {
7 changes: 0 additions & 7 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -13,8 +13,6 @@
from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
from torch_mlir._version import torch_version_for_comparison, version

print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison())

LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"Conv1dNoPaddingGroupModule_basic",
"RepeatInterleaveStaticModule_basic",
@@ -352,11 +350,6 @@
"InterpolateDynamicModule_scales_recompute_bilinear",
}

-if torch_version_for_comparison() <= version.parse("2.2.0"):
-    TORCHDYNAMO_XFAIL_SET |= {
-        'OneHotModule_basic',
-    }

TORCHDYNAMO_CRASHING_SET = {
# No upstream decompositions.
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
4 changes: 1 addition & 3 deletions projects/pt1/python/torch_mlir/dynamo.py
@@ -50,6 +50,7 @@ def _get_decomposition_table():
        # (the upstream decomposition we use here does), even though we have
        # support for aten.native_batch_norm_backward.
        aten._native_batch_norm_legit_functional,
+       aten._native_batch_norm_legit_no_training,
        aten.native_group_norm,
        aten.split.Tensor,
        aten.split_with_sizes,
@@ -67,9 +68,6 @@
        aten.cumsum,
        aten.index_select,
    ]
-   # TODO: enable test once 2.1.0 is stable
-   if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
-       decomp_list += [aten._native_batch_norm_legit_no_training]
    return get_decompositions(decomp_list)


@@ -18,14 +18,6 @@
"ElementwiseToDtypeI64ToUI8Module_basic",
}

# TODO: Delete once torch 2.1.0 is released
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
COMMON_TORCH_MLIR_LOWERING_XFAILS.update({
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionSameModule_basic"
})


def register_all_tests():
"""Registers all the built-in E2E tests that Torch-MLIR provides."""
# Side-effecting import statements.
4 changes: 2 additions & 2 deletions stable-requirements.txt
@@ -1,3 +1,3 @@
--index-url https://download.pytorch.org/whl/cpu
-torch==2.1.2+cpu
-torchvision==0.16.2+cpu
+torch==2.3.1+cpu
+torchvision==0.18.1+cpu
2 changes: 1 addition & 1 deletion test/Conversion/TorchToLinalg/elementwise.mlir
@@ -67,7 +67,7 @@ func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[C1:.*]] = torch.constant.int 1
-// CHECK: %[[BUILTIN_C1:.*]] = torch_c.to_i64 %[[C1]]
+// CHECK: %[[BUILTIN_C1:.*]] = arith.constant 1 : i64
// CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>]
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[ALPHA:.*]] = arith.sitofp %[[BUILTIN_C1]] : i64 to f32
14 changes: 7 additions & 7 deletions test/Conversion/TorchToLinalg/pooling.mlir
@@ -11,15 +11,15 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt
%int7 = torch.constant.int 7
%int8 = torch.constant.int 8
%false = torch.constant.bool false
-// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
-// CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
+// CHECK: %[[C1:.*]] = arith.constant 1 : i64
+// CHECK: %[[C2:.*]] = arith.constant 2 : i64
// CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
-// CHECK: %[[T2:.*]] = arith.index_cast %[[C2]] : i64 to index
-// CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]], %[[T2]]) : tensor<?x?xf32>
-// CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[T1:.*]] = arith.constant 1 : index
+// CHECK: %[[T2:.*]] = arith.constant 2 : index
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x2xf32>
+// CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<1x2xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
%kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int5, %int6 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -66,7 +66,7 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
// CHECK: } : tensor<?x?x?x?x?xf32> to tensor<?x?x?x?x?xf32>

// CHECK: %[[OUTPUT_TENSOR:.*]] = linalg.fill ins(%[[MIN_VALUE:.*]] : f32) outs(%{{.*}} : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-// CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[OUTPUT_TENSOR:.*]] : tensor<?x?x?x?x?xf32>) {
+// CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor<?x?x?x?x?xf32>, tensor<8x8x8xf32>) outs(%[[OUTPUT_TENSOR:.*]] : tensor<?x?x?x?x?xf32>) {
// CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[KERNEL:.*]]: f32, %[[ACC_OUT:.*]]: f32):
// CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32
// CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32
10 changes: 5 additions & 5 deletions test/Conversion/TorchToSCF/basic.mlir
@@ -4,9 +4,9 @@
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch_c.to_i1 %[[VAL_0]]
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_3:.*]] = torch_c.to_i64 %[[VAL_2]]
+// CHECK: %[[VAL_3:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
-// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]]
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_1]] -> (i64) {
// CHECK: scf.yield %[[VAL_3]] : i64
// CHECK: } else {
@@ -31,11 +31,11 @@ func.func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
// CHECK: %[[VAL_2:.*]] = torch_c.to_i1 %[[VAL_0]]
// CHECK: %[[VAL_3:.*]] = torch_c.to_i1 %[[VAL_1]]
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]]
+// CHECK: %[[VAL_5:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
-// CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]]
+// CHECK: %[[VAL_7:.*]] = arith.constant 3 : i64
// CHECK: %[[VAL_8:.*]] = torch.constant.int 4
-// CHECK: %[[VAL_9:.*]] = torch_c.to_i64 %[[VAL_8]]
+// CHECK: %[[VAL_9:.*]] = arith.constant 4 : i64
// CHECK: %[[VAL_10:.*]] = scf.if %[[VAL_2]] -> (i64) {
// CHECK: %[[VAL_11:.*]] = scf.if %[[VAL_3]] -> (i64) {
// CHECK: scf.yield %[[VAL_5]] : i64
Expand Down