diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index ed2726fa1b42..779bd6249283 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2766,21 +2766,23 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only ranked tensor types with static shapes are currently supported"); - SmallVector dimListInt; - if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dimListInt))) + SmallVector dimListInt64; + if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dimListInt64))) return rewriter.notifyMatchFailure( op, "Only constant dimensions are currently supported"); + SmallVector dimListInt32; + copy(dimListInt64, std::back_inserter(dimListInt32)); int64_t selfRank = selfType.getRank(); // TODO: If this is already verified on the op then we can drop checking here. - for (auto &d : dimListInt) { + for (auto &d : dimListInt32) { d = toPositiveDim(d, selfRank); if (!isValidDim(d, selfRank)) return rewriter.notifyMatchFailure(op, "Not all dims are valid"); } - auto transposeDimsConst = mlir::tosa::getConstTensor( - rewriter, op.getOperation(), dimListInt, {selfRank}); + auto transposeDimsConst = mlir::tosa::getConstTensor( + rewriter, op.getOperation(), dimListInt32, {selfRank}); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index eb852f6c76b2..7d046177fc14 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -797,8 +797,8 @@ func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi32>) -> tensor<3x2x4xf32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32> // CHECK: }