diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 000000000000..3ab6783bdb61
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,15 @@
+# To get started with Dependabot version updates, you'll need to specify which
+# package ecosystems to update and where the package manifests are located.
+# Please see the documentation for all configuration options:
+# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
+
+version: 2
+updates:
+  - package-ecosystem: "gitsubmodule"
+    directory: "/"
+    allow:
+      - dependency-name: "externals/llvm-project"
+    schedule:
+      interval: "daily"
+      time: "06:00"
+      timezone: "Europe/Berlin"
diff --git a/.github/workflows/approve_dependabot.yml b/.github/workflows/approve_dependabot.yml
new file mode 100644
index 000000000000..ca3f6b6e9930
--- /dev/null
+++ b/.github/workflows/approve_dependabot.yml
@@ -0,0 +1,28 @@
+name: Dependabot auto-approve & auto-merge
+on: pull_request
+
+permissions:
+  pull-requests: write
+  # Needed to enable auto-merge
+  contents: write
+
+jobs:
+  dependabot:
+    runs-on: ubuntu-latest
+    if: github.actor == 'dependabot[bot]'
+    steps:
+      - name: Dependabot metadata
+        id: metadata
+        uses: dependabot/fetch-metadata@v2
+        with:
+          github-token: "${{ secrets.GITHUB_TOKEN }}"
+      - name: Approve a PR
+        run: gh pr review --approve "$PR_URL"
+        env:
+          PR_URL: ${{github.event.pull_request.html_url}}
+          GH_TOKEN: ${{secrets.GITHUB_TOKEN}}
+      - name: Enable auto-merge for Dependabot PRs
+        run: gh pr merge --auto --merge "$PR_URL"
+        env:
+          PR_URL: ${{github.event.pull_request.html_url}}
+          GH_TOKEN: ${{secrets.GITHUB_TOKEN}}
diff --git a/externals/llvm-project b/externals/llvm-project
index fa72e6813bb0..ecad3c58548d 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit fa72e6813bb05f5d13e7993f22c51cdb2ff8965a
+Subproject commit ecad3c58548d08901ed340c34276d8534681ce93
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 44ac95ce0429..1db603cc5aa0 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -682,7 +682,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     Type intType = IntegerType::get(context, 64);
 
     auto castIndexToInt = [&](Value v) {
-      return rewriter.create<arith::IndexCastOp>(loc, intType, v);
+      return rewriter.createOrFold<arith::IndexCastOp>(loc, intType, v);
     };
 
     SmallVector<Value> paddingIntValues;
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index c83025e42e67..0e49eee04745 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -86,16 +86,10 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
        pad < paddingIncludingUnchanged.end(); pad++)
     *pad = castIntToIndex(b, loc, *pad);
 
-  Type elementType = input.getType().cast<RankedTensorType>().getElementType();
-  // TODO: audit possibility of sparsity on this tensor
-  Type inputType =
-      RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef(
-                                SmallVector<int64_t>(inRank, kUnknownSize))),
-                            elementType);
-
   SmallVector<OpFoldResult> paddingValues =
       getAsOpFoldResult(paddingIncludingUnchanged);
-  return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,
+
+  return b.create<tensor::PadOp>(loc, Type{}, input, /*low=*/paddingValues,
                                  /*high=*/paddingValues, pad);
 }
 
@@ -107,25 +101,25 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
   Value c1 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
   Value c2 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));
 
-  Value doublePadding = b.create<arith::MulIOp>(loc, paddingInt, c2);
+  Value doublePadding = b.createOrFold<arith::MulIOp>(loc, paddingInt, c2);
   // in + 2 * padding
-  Value inAddDoublePadding =
-      b.create<arith::AddIOp>(loc, castIndexToInt64(b, loc, in), doublePadding);
+  Value inAddDoublePadding = b.createOrFold<arith::AddIOp>(
+      loc, castIndexToInt64(b, loc, in), doublePadding);
   // dilation * (kernelSize - 1)
-  Value kernelSizeSub1 = b.create<arith::SubIOp>(loc, kernelSizeInt, c1);
+  Value kernelSizeSub1 = b.createOrFold<arith::SubIOp>(loc, kernelSizeInt, c1);
   Value dilationTimesKernelSize =
-      b.create<arith::MulIOp>(loc, dilationInt, kernelSizeSub1);
+      b.createOrFold<arith::MulIOp>(loc, dilationInt, kernelSizeSub1);
 
-  Value temp =
-      b.create<arith::SubIOp>(loc, inAddDoublePadding, dilationTimesKernelSize);
-  Value dividend = b.create<arith::SubIOp>(loc, temp, c1);
+  Value temp = b.createOrFold<arith::SubIOp>(loc, inAddDoublePadding,
+                                             dilationTimesKernelSize);
+  Value dividend = b.createOrFold<arith::SubIOp>(loc, temp, c1);
   Value division;
   if (ceilMode)
-    division = b.create<arith::CeilDivSIOp>(loc, dividend, strideInt);
+    division = b.createOrFold<arith::CeilDivSIOp>(loc, dividend, strideInt);
   else
-    division = b.create<arith::FloorDivSIOp>(loc, dividend, strideInt);
-  Value out = b.create<arith::AddIOp>(loc, division, c1);
+    division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt);
+  Value out = b.createOrFold<arith::AddIOp>(loc, division, c1);
   return castIntToIndex(b, loc, out);
 }
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index 064215c51da0..4d42b5fea943 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -139,13 +139,13 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
 }
 
 Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
-  assert(v.getType().isa<IntegerType>() && "must be called with integer type");
-  return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
+  assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
+  return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);
 }
 
 Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
-  assert(idx.getType().isa<IndexType>() && "must be called with integer type");
-  return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
+  assert(isa<IndexType>(idx.getType()) && "must be called with integer type");
+  return b.createOrFold<arith::IndexCastOp>(loc, b.getI64Type(), idx);
 }
 
 SmallVector<Value>
diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
index 1cda55724ee3..3bba2be4d5f2 100644
--- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -94,7 +94,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
         if (!inputs[0].getType().isa<Torch::IntType>())
           return std::nullopt;
         assert(inputs.size() == 1);
-        return builder.create<ToI64Op>(loc, inputs[0]).getResult();
+        return builder.createOrFold<ToI64Op>(loc, inputs[0]);
       });
   auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type,
                                   ValueRange inputs, Location loc) -> Value {
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index cc88728fa642..38832b8684b6 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -13,8 +13,6 @@
 from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
 from torch_mlir._version import torch_version_for_comparison, version
 
-print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison())
-
 LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
     "Conv1dNoPaddingGroupModule_basic",
     "RepeatInterleaveStaticModule_basic",
@@ -352,11 +350,6 @@
     "InterpolateDynamicModule_scales_recompute_bilinear",
 }
 
-if torch_version_for_comparison() <= version.parse("2.2.0"):
-    TORCHDYNAMO_XFAIL_SET |= {
-        'OneHotModule_basic',
-    }
-
 TORCHDYNAMO_CRASHING_SET = {
     # No upstream decompositions.
     # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
diff --git a/projects/pt1/python/torch_mlir/dynamo.py b/projects/pt1/python/torch_mlir/dynamo.py
index 7fc887d56bc4..7622ec013659 100644
--- a/projects/pt1/python/torch_mlir/dynamo.py
+++ b/projects/pt1/python/torch_mlir/dynamo.py
@@ -50,6 +50,7 @@ def _get_decomposition_table():
         # (the upstream decomposition we use here does), even though we have
         # support for aten.native_batch_norm_backward.
         aten._native_batch_norm_legit_functional,
+        aten._native_batch_norm_legit_no_training,
         aten.native_group_norm,
         aten.split.Tensor,
         aten.split_with_sizes,
@@ -67,9 +68,6 @@ def _get_decomposition_table():
         aten.cumsum,
         aten.index_select,
     ]
-    # TODO: enable test once 2.1.0 is stable
-    if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
-        decomp_list += [aten._native_batch_norm_legit_no_training]
     return get_decompositions(decomp_list)
 
 
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
index 6f492a1eff5c..c03fd95505a6 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -18,14 +18,6 @@
     "ElementwiseToDtypeI64ToUI8Module_basic",
 }
 
-# TODO: Delete once torch 2.1.0 is released
-if torch_version_for_comparison() < version.parse("2.1.0.dev"):
-    COMMON_TORCH_MLIR_LOWERING_XFAILS.update({
-        "ScaledDotProductAttentionDifferentModule_basic",
-        "ScaledDotProductAttentionSameModule_basic"
-    })
-
-
 def register_all_tests():
     """Registers all the built-in E2E tests that Torch-MLIR provides."""
     # Side-effecting import statements.
diff --git a/stable-requirements.txt b/stable-requirements.txt
index 1641e0540671..27d0c30d7a91 100644
--- a/stable-requirements.txt
+++ b/stable-requirements.txt
@@ -1,3 +1,3 @@
 --index-url https://download.pytorch.org/whl/cpu
-torch==2.1.2+cpu
-torchvision==0.16.2+cpu
+torch==2.3.1+cpu
+torchvision==0.18.1+cpu
diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir
index bed94f98da2b..2ed7906cc56c 100644
--- a/test/Conversion/TorchToLinalg/elementwise.mlir
+++ b/test/Conversion/TorchToLinalg/elementwise.mlir
@@ -67,7 +67,7 @@ func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>,
 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
 // CHECK: %[[C1:.*]] = torch.constant.int 1
-// CHECK: %[[BUILTIN_C1:.*]] = torch_c.to_i64 %[[C1]]
+// CHECK: %[[BUILTIN_C1:.*]] = arith.constant 1 : i64
 // CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>]
 // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
 // CHECK: %[[ALPHA:.*]] = arith.sitofp %[[BUILTIN_C1]] : i64 to f32
diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir
index 8a359ed5627d..4c3a279de440 100644
--- a/test/Conversion/TorchToLinalg/pooling.mlir
+++ b/test/Conversion/TorchToLinalg/pooling.mlir
@@ -11,15 +11,15 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt
   %int7 = torch.constant.int 7
   %int8 = torch.constant.int 8
   %false = torch.constant.bool false
-  // CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
-  // CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
+  // CHECK: %[[C1:.*]] = arith.constant 1 : i64
+  // CHECK: %[[C2:.*]] = arith.constant 2 : i64
   // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
   // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
   // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-  // CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
-  // CHECK: %[[T2:.*]] = arith.index_cast %[[C2]] : i64 to index
-  // CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]], %[[T2]]) : tensor<?x?xf32>
-  // CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: %[[T1:.*]] = arith.constant 1 : index
+  // CHECK: %[[T2:.*]] = arith.constant 2 : index
+  // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x2xf32>
+  // CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<1x2xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   %kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
   %stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
   %padding = torch.prim.ListConstruct %int5, %int6 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -66,7 +66,7 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
 // CHECK: } : tensor<?x?x?x?x?xf32> to tensor<?x?x?x?x?xf32>
 // CHECK: %[[OUTPUT_TENSOR:.*]] = linalg.fill ins(%[[MIN_VALUE:.*]] : f32) outs(%{{.*}} : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-// CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[OUTPUT_TENSOR:.*]] : tensor<?x?x?x?x?xf32>) {
+// CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor<?x?x?x?x?xf32>, tensor<8x8x8xf32>) outs(%[[OUTPUT_TENSOR:.*]] : tensor<?x?x?x?x?xf32>) {
 // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[KERNEL:.*]]: f32, %[[ACC_OUT:.*]]: f32):
 // CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32
 // CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32
diff --git a/test/Conversion/TorchToSCF/basic.mlir b/test/Conversion/TorchToSCF/basic.mlir
index fadac3b4f97d..65ce89f494d1 100644
--- a/test/Conversion/TorchToSCF/basic.mlir
+++ b/test/Conversion/TorchToSCF/basic.mlir
@@ -4,9 +4,9 @@
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_i1 %[[VAL_0]]
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_3:.*]] = torch_c.to_i64 %[[VAL_2]]
+// CHECK: %[[VAL_3:.*]] = arith.constant 2 : i64
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
-// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]]
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : i64
 // CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_1]] -> (i64) {
 // CHECK: scf.yield %[[VAL_3]] : i64
 // CHECK: } else {
@@ -31,11 +31,11 @@ func.func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
 // CHECK: %[[VAL_2:.*]] = torch_c.to_i1 %[[VAL_0]]
 // CHECK: %[[VAL_3:.*]] = torch_c.to_i1 %[[VAL_1]]
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]]
+// CHECK: %[[VAL_5:.*]] = arith.constant 2 : i64
 // CHECK: %[[VAL_6:.*]] = torch.constant.int 3
-// CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]]
+// CHECK: %[[VAL_7:.*]] = arith.constant 3 : i64
 // CHECK: %[[VAL_8:.*]] = torch.constant.int 4
-// CHECK: %[[VAL_9:.*]] = torch_c.to_i64 %[[VAL_8]]
+// CHECK: %[[VAL_9:.*]] = arith.constant 4 : i64
 // CHECK: %[[VAL_10:.*]] = scf.if %[[VAL_2]] -> (i64) {
 // CHECK: %[[VAL_11:.*]] = scf.if %[[VAL_3]] -> (i64) {
 // CHECK: scf.yield %[[VAL_5]] : i64
diff --git a/test/Conversion/TorchToStablehlo/elementwise.mlir b/test/Conversion/TorchToStablehlo/elementwise.mlir
index 367985233577..3ff1d095c532 100644
--- a/test/Conversion/TorchToStablehlo/elementwise.mlir
+++ b/test/Conversion/TorchToStablehlo/elementwise.mlir
@@ -104,8 +104,7 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
 // CHECK-LABEL: func.func @torch.aten.addscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
@@ -125,10 +124,8 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 // CHECK-LABEL: func.func @torch.aten.addscalar$alpha(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
@@ -166,10 +163,9 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
@@ -205,8 +201,7 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
 // CHECK-LABEL: func.func @torch.aten.subscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
@@ -226,8 +221,7 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 // CHECK-LABEL: func.func @torch.aten.rsubscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
@@ -247,10 +241,8 @@ func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
 // CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
@@ -288,10 +280,9 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK-LABEL: func.func @torch.aten.subtensor$alpha(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
@@ -327,8 +318,7 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
 // CHECK-LABEL: func.func @torch.aten.mulscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
@@ -360,8 +350,7 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK-LABEL: func.func @torch.aten.divscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
@@ -393,8 +382,7 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK-LABEL: func.func @torch.aten.gt.scalar(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]]
+// CHECK: %[[T1:.*]] = arith.constant 3 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
@@ -636,4 +624,4 @@ func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>
 func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64>{
   %0 = torch.aten.abs %arg0 : !torch.vtensor<[15,15],si64> -> !torch.vtensor<[15,15],si64>
   return %0 : !torch.vtensor<[15,15],si64>
-}
\ No newline at end of file
+}
diff --git a/test/Conversion/TorchToStablehlo/linear.mlir b/test/Conversion/TorchToStablehlo/linear.mlir
index 7f253a98df04..a72ca1c206d7 100644
--- a/test/Conversion/TorchToStablehlo/linear.mlir
+++ b/test/Conversion/TorchToStablehlo/linear.mlir
@@ -278,7 +278,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
 // CHECK: %[[T_5:.*]] = torch.constant.int 1
 // CHECK: %[[T_6:.*]] = torch.constant.int 4
 // CHECK: %[[T_7:.*]] = torch.constant.int 3
-// CHECK: %[[T_8:.*]] = torch_c.to_i64 %[[T_7]]
+// CHECK: %[[T_8:.*]] = arith.constant 3 : i64
 // CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
@@ -314,8 +314,7 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
 // CHECK: %int2 = torch.constant.int 2
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %int4 = torch.constant.int 4
-// CHECK: %int3 = torch.constant.int 3
-// CHECK: %[[T_3:.*]] = torch_c.to_i64 %int3
+// CHECK: %[[T_3:.*]] = arith.constant 3 : i64
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -357,7 +356,7 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
 // CHECK: %none = torch.constant.none
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
+// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32>
@@ -388,7 +387,7 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
 // CHECK: %none = torch.constant.none
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
+// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
 // CHECK: %int2 = torch.constant.int 2
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -423,7 +422,7 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
 // CHECK: %none = torch.constant.none
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
+// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
 // CHECK: %int2 = torch.constant.int 2
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -459,12 +458,12 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %int2 = torch.constant.int 2
-// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int2
+// CHECK: %[[T_2:.*]] = arith.constant 2 : i64
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32>
-// CHECK: %[[T_7:.*]] = stablehlo.reverse %6, dims = [0, 1] : tensor<3x3x2x2xf32>
+// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32>
 // CHECK: %c0 = arith.constant 0 : index
 // CHECK: %dim = tensor.dim %[[T_7]], %c0 : tensor<3x3x2x2xf32>
 // CHECK: %[[T_8:.*]] = arith.index_cast %dim : index to i64
@@ -477,14 +476,14 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
 // CHECK: %c3 = arith.constant 3 : index
 // CHECK: %dim_2 = tensor.dim %[[T_7]], %c3 : tensor<3x3x2x2xf32>
 // CHECK: %[[T_11:.*]] = arith.index_cast %dim_2 : index to i64
-// CHECK: %c2_i64 = arith.constant 2 : i64
-// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %c2_i64 : i64
-// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %c2_i64 : i64
-// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %c2_i64, %[[T_12]] : tensor<5xi64>
+// CHECK: %[[C2:.*]] = arith.constant 2 : i64
+// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %[[C2]] : i64
+// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %[[C2]] : i64
+// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %[[C2]], %[[T_12]] : tensor<5xi64>
 // CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xi64>) -> tensor<3x3x2x2x1xf32>
 // CHECK: %[[T_15:.*]] = stablehlo.transpose %[[T_14]], dims = [0, 1, 3, 2, 4] : (tensor<3x3x2x2x1xf32>) -> tensor<3x3x2x2x1xf32>
-// CHECK: %from_elements_3 = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64>
-// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %from_elements_3 : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32>
+// CHECK: %[[from_elements_3:.*]] = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64>
+// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %[[from_elements_3]] : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32>
 // CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]])
 // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x1xf32>) -> tensor<1x4x15x15xf32>
 // CHECK: %[[T_18:.*]] = torch_c.from_builtin_tensor %[[T_17]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
diff --git a/test/Conversion/TorchToStablehlo/view_like.mlir b/test/Conversion/TorchToStablehlo/view_like.mlir
index 206084873c81..5de40484f401 100644
--- a/test/Conversion/TorchToStablehlo/view_like.mlir
+++ b/test/Conversion/TorchToStablehlo/view_like.mlir
@@ -3,12 +3,9 @@
 // CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]]
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
-// CHECK: %[[INT10:.*]] = torch.constant.int 10
-// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT10]]
+// CHECK: %[[T1:.*]] = arith.constant 0 : i64
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
+// CHECK: %[[T3:.*]] = arith.constant 10 : i64
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
@@ -42,7 +39,7 @@
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32>
 func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -58,12 +55,9 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
 // CHECK-LABEL: func.func @torch.aten.slice.strided.static$slice_like(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]]
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
-// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807
-// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]]
+// CHECK: %[[T1:.*]] = arith.constant 0 : i64
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
+// CHECK: %[[T3:.*]] = arith.constant 9223372036854775807 : i64
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32>
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
@@ -97,7 +91,7 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32>
 func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
@@ -113,12 +107,9 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
 // CHECK-LABEL: func.func @torch.aten.slice.last$slice_like(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]]
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]]
-// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
-// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1
+// CHECK: %[[T1:.*]] = arith.constant 0 : i64
+// CHECK: %[[T2:.*]] = arith.constant 1 : i64
+// CHECK: %[[T3:.*]] = arith.constant -1 : i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?xf32>
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
@@ -152,7 +143,7 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32>
 func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
@@ -168,12 +159,9 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
 // CHECK-LABEL: func.func @torch.aten.slice.last.static$slice_like(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]]
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]]
-// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
-// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1
+// CHECK: %[[T1:.*]] = arith.constant 0 : i64
+// CHECK: %[[T2:.*]] = arith.constant 1 : i64
+// CHECK: %[[T3:.*]] = arith.constant -1 : i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32>
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
@@ -207,7 +195,7 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32>
 func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
@@ -224,8 +212,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T1:.*]] = arith.constant 2 : i64
 // CHECK: %[[NONE:.*]] = torch.constant.none
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?xf32>
@@ -247,7 +234,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
+// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS_6]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
 // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32>
 func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -264,8 +251,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32>
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T1:.*]] = arith.constant 2 : i64
 // CHECK: %[[NONE:.*]] = torch.constant.none
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32>
@@ -287,7 +273,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
+// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS_6]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
 // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
 // CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32>
 func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {