From b08d08682f2b3a32ba0b9c0130396cb9d684b135 Mon Sep 17 00:00:00 2001
From: Justin Ngo
Date: Mon, 7 Oct 2024 10:28:26 -0700
Subject: [PATCH] [TOSA] Add legalization for fill, flip, and round (#3768)

- Add Torch to TOSA lowering for aten.fill.Scalar/Tensor, aten.flip,
  and aten.round
- Fix torchScalarToTosaTensor function to correctly convert Torch
  scalar input to TOSA tensor
- Update xfail_sets.py with new e2e results
- Update basic.mlir with LIT tests for new ops

Change-Id: If1e42c2e582710dd8ad0465eed29806fbcdbde41
Signed-off-by: Justin Ngo
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 211 ++++++++++++++++++---
 projects/pt1/e2e_testing/xfail_sets.py     |  62 +++---
 test/Conversion/TorchToTosa/basic.mlir     |  81 ++++++++
 3 files changed, 298 insertions(+), 56 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 5664ebc7152db..77672181416f7 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -153,11 +153,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(op,
                                        "Unable to extract the scalar constant");
 
+  int64_t numElem = 1;
+  for (int64_t dim : dshape)
+    numElem *= dim;
+
   if (isa<mlir::FloatType>(dtype)) {
-    tosaTensor = tosa::getConstTensor<float>(rewriter, op,
-                                             (isFloat ? doubleValue : intValue),
-                                             dshape, dtype)
-                     .value();
+    tosaTensor =
+        tosa::getConstTensor<float>(
+            rewriter, op,
+            SmallVector<float>(numElem, (isFloat ? doubleValue : intValue)),
+            dshape, dtype)
+            .value();
   } else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
     auto w = intType.getWidth();
     if (w != 1 && w != 32 && w != 64)
@@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
     }
     bool d = isFloat ? static_cast<bool>(doubleValue)
                      : static_cast<bool>(intValue);
-    tosaTensor =
-        tosa::getConstTensor<bool>(rewriter, op, {d}, dshape).value();
+    tosaTensor = tosa::getConstTensor<bool>(
+                     rewriter, op, SmallVector<bool>(numElem, d), dshape)
+                     .value();
   } else if (w == 32) {
     if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
       return rewriter.notifyMatchFailure(
@@ -183,8 +190,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
          "of destination type");
     }
     int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
                         : static_cast<int32_t>(intValue);
-    tosaTensor =
-        tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
+    tosaTensor = tosa::getConstTensor<int32_t>(
+                     rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
+                     .value();
   } else if (w == 64) {
     if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
       return rewriter.notifyMatchFailure(
@@ -192,8 +200,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
          "of destination type");
     }
    int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
-    tosaTensor =
-        tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
+    tosaTensor = tosa::getConstTensor<int64_t>(
+                     rewriter, op, SmallVector<int64_t>(numElem, d), dshape)
+                     .value();
    }
  } else {
    return rewriter.notifyMatchFailure(op, "Unsupported element type");
  }
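The helper previously emitted a single-element constant regardless of dshape; with numElem copies it now materializes a full splat. A minimal NumPy sketch of the intended semantics (illustrative names, not code from this patch):

```python
import numpy as np

def torch_scalar_to_tosa_tensor(value, dshape, dtype):
    # Mirror of the fix above: one copy of the scalar per output element
    # (numElem = product of dshape), shaped as dshape.
    num_elem = 1
    for dim in dshape:
        num_elem *= dim
    return np.array([value] * num_elem, dtype=dtype).reshape(dshape)

print(torch_scalar_to_tosa_tensor(7, (2, 3), np.int32))
# [[7 7 7]
#  [7 7 7]]
```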
@@ -5320,7 +5329,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
 };
 
 template <typename AtenOpT>
-class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
+class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
 public:
   using OpConversionPattern<AtenOpT>::OpConversionPattern;
   using OpAdaptor = typename AtenOpT::Adaptor;
@@ -5336,18 +5345,48 @@ class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
         op, "Only Tensor types with static shapes are currently supported");
 
     Type outElemTy = outType.getElementType();
-    if (!outElemTy.isIntOrFloat()) {
+    if (!outElemTy.isIntOrFloat())
       return rewriter.notifyMatchFailure(
           op, "Only floating-point or integer datatype legalization supported");
+
+    Value fillValueTargetTensor;
+    if constexpr (std::is_same<AtenOpT, AtenFillTensorOp>()) {
+      // Reshape value tensor to have same rank and shape as input
+      auto inputRank =
+          cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
+
+      auto fillValue = adaptor.getValue();
+      auto fillValueType = dyn_cast<TensorType>(fillValue.getType());
+      if (!fillValueType)
+        return rewriter.notifyMatchFailure(op, "Fill value is not a tensor");
+      auto fillValueElemTy = fillValueType.getElementType();
+
+      SmallVector<int64_t> fillValueMatchedInputRankShape(inputRank, 1);
+
+      auto fillValueMatchedInputRankType = RankedTensorType::get(
+          makeShapeTorchCompatible(fillValueMatchedInputRankShape),
+          fillValueElemTy);
+
+      auto fillValueMatchedInputRankTensor = rewriter.create<tosa::ReshapeOp>(
+          op->getLoc(), fillValueMatchedInputRankType, fillValue,
+          rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
+
+      fillValueTargetTensor = rewriter.create<tosa::TileOp>(
+          op->getLoc(),
+          RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
+                                fillValueElemTy),
+          fillValueMatchedInputRankTensor.getResult(),
+          makeShapeTorchCompatible(outType.getShape()));
+    } else {
+      if (failed(torchScalarToTosaTensor(
+              rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
+              makeShapeTorchCompatible(outType.getShape()))))
+        return rewriter.notifyMatchFailure(
+            op, "Fill value must be a scalar constant");
     }
 
-    Value constOp;
-    if (failed(torchScalarToTosaTensor(
-            rewriter, op, op.getValue(), constOp, outElemTy,
-            makeShapeTorchCompatible(outType.getShape()))))
-      return rewriter.notifyMatchFailure(
-          op, "Supplied value must be a Scalar constant");
-    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
+                                              fillValueTargetTensor);
 
     return success();
   }
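For aten.fill.Tensor, the fill value is reshaped to the input's rank (a shape of all ones) and then tiled out to the full output shape before the final cast. A NumPy sketch of that reshape-then-tile sequence (shapes taken from the LIT test below, names illustrative):

```python
import numpy as np

out_shape = (1, 12, 128, 128)               # static output shape
fill_value = np.array([3], dtype=np.int32)  # rank-1 fill value tensor

# tosa.reshape: match the input rank with a shape of all ones.
matched_rank = fill_value.reshape((1,) * len(out_shape))
# tosa.tile: repeat along every dimension out to the output shape.
tiled = np.tile(matched_rank, out_shape)
# tosa.cast: convert to the output element type (i32 -> f32 here).
result = tiled.astype(np.float32)

assert result.shape == out_shape and np.all(result == 3.0)
```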
@@ -5869,6 +5908,127 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.flip
+template <>
+LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
+    AtenFlipOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto self = adaptor.getSelf();
+
+  auto selfTy = dyn_cast<RankedTensorType>(self.getType());
+  if (!selfTy)
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types are currently supported");
+
+  SmallVector<int64_t> dims;
+  if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant dims are currently supported");
+
+  auto selfRank = selfTy.getRank();
+
+  auto resultTy = getTypeConverter()->convertType(op.getType());
+  Value result = self;
+
+  for (auto &dim : dims) {
+    dim = toPositiveDim(dim, selfRank);
+    if (!isValidDim(dim, selfRank))
+      return rewriter.notifyMatchFailure(op, "Not all dims are valid");
+
+    result = rewriter.create<tosa::ReverseOp>(op->getLoc(), resultTy, result,
+                                              static_cast<int32_t>(dim));
+  }
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+// Legalization for aten.round:
+// Rounds elements of input to the nearest integer.
+// Implements "round half to even" to break ties when a number is equidistant
+// from two integers.
+template <>
+LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
+    AtenRoundOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // To round to the nearest integer, we will consider the fractional part of
+  // the input element (= input element - integer part of element). If the
+  // fractional part is smaller than 0.5, round the number down. If the
+  // fractional part is 0.5, apply "round half to even" rule. If the fractional
+  // part is greater than 0.5, round up.
+  //
+  // if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
+  //   res = floor(input)
+  // else:
+  //   res = ceil(input)
+
+  auto self = adaptor.getSelf();
+
+  auto selfTy = dyn_cast<TensorType>(self.getType());
+  if (!selfTy)
+    return rewriter.notifyMatchFailure(op, "Only tensor types supported");
+
+  auto resultTy =
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+
+  auto boolTy =
+      RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
+
+  auto resultElemTy = resultTy.getElementType();
+
+  auto oneHalf =
+      tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
+
+  auto two =
+      tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
+
+  auto floorInput =
+      rewriter.create<tosa::FloorOp>(op->getLoc(), resultTy, self);
+
+  // input - floor(input)
+  auto fractionalPart = rewriter.create<tosa::SubOp>(
+      op->getLoc(), resultTy, self, floorInput.getResult());
+
+  auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
+
+  auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
+      op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
+
+  auto floorDivResult = rewriter.create<tosa::FloorOp>(
+      op->getLoc(), resultTy, floorInputDivByTwo.getResult());
+
+  // (floor(input) // 2) * 2
+  auto evenComparison = rewriter.create<tosa::MulOp>(
+      op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
+
+  // (floor(input) // 2) * 2 == floor(input) <=> floor(input) % 2 == 0
+  auto floorInputEven = rewriter.create<tosa::EqualOp>(
+      op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult());
+
+  auto fracEqualOneHalf = rewriter.create<tosa::EqualOp>(
+      op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
+
+  auto fracLtOneHalf = rewriter.create<tosa::GreaterOp>(
+      op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
+
+  // (frac == 0.5) && (floor(input) % 2 == 0)
+  auto fracEqualOneHalfCond = rewriter.create<tosa::LogicalAndOp>(
+      op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
+      floorInputEven.getResult());
+
+  // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
+  auto floorResultCond = rewriter.create<tosa::LogicalOrOp>(
+      op->getLoc(), boolTy, fracLtOneHalf.getResult(),
+      fracEqualOneHalfCond.getResult());
+
+  rewriter.replaceOpWithNewOp<tosa::SelectOp>(
+      op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
+      ceilInput.getResult());
+
+  return success();
+}
+
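The branch condition above can be sanity-checked against Python's built-in round, which also rounds half to even (a standalone sketch, not part of the patch):

```python
import math

def tosa_round(x):
    # if (frac < 0.5 || (frac == 0.5 && floor(x) % 2 == 0)): floor else ceil
    frac = x - math.floor(x)
    if frac < 0.5 or (frac == 0.5 and math.floor(x) % 2 == 0):
        return math.floor(x)
    return math.ceil(x)

for v in [2.3, 2.5, 3.5, -0.5, -1.5, -2.7]:
    assert tosa_round(v) == round(v)  # round() is half-to-even in Python
```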
 // Template to create supporting diagonal mask tensor for aten.diagonal
 template <typename T>
 Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
@@ -6052,6 +6212,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -6283,11 +6444,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
 #undef INSERT_CONSTANT_FILL_PATTERN
 
-#define INSERT_FILL_SCALAR_PATTERN(AtenOp)                                     \
+#define INSERT_FILL_PATTERN(AtenOp)                                            \
   target.addIllegalOp<AtenOp>();                                               \
-  patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
-  INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
-#undef INSERT_FILL_SCALAR_PATTERN
+  patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
+  INSERT_FILL_PATTERN(AtenFill_ScalarOp);
+  INSERT_FILL_PATTERN(AtenFillScalarOp);
+  INSERT_FILL_PATTERN(AtenFillTensorOp);
+#undef INSERT_FILL_PATTERN
 
 #define INSERT_MASKED_FILL_PATTERN(AtenOp)                                     \
   target.addIllegalOp<AtenOp>();                                               \
@@ -6359,6 +6522,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenTrilOp);
     INSERT_ATENOP_PATTERN(AtenDiagonalOp);
     INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
+    INSERT_ATENOP_PATTERN(AtenFlipOp);
+    INSERT_ATENOP_PATTERN(AtenRoundOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index bd8d1994d9b46..09db1098e4b16 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1663,6 +1663,22 @@
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "AtenRoundFloatHalfToEvenModule_basic",
+    "AtenRoundFloatModule_basic",
+    "FakeQuantizePerTensorAffineCachemaskModule_basic",
+    "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
+    "FakeQuantizePerTensorAffineModule_basic",
+    "FakeQuantizePerTensorAffineRoundToEvenModule_basic",
+    "Fill_TensorFloat64WithFloat32Static_basic",
+    "Fill_TensorFloat64WithInt64Static_basic",
+    "FlipModuleStaticShape_basic",
+    "FlipModule_basic",
+    "FlipNegativeIndexModule_basic",
+    "Rot90BasicModule_basic",
+    "Rot90DynamicDimsModule_basic",
+    "Rot90MultipleRotationsModule_basic",
+    "Rot90NegativeEvenRotationsModule_basic",
+    "Rot90NegativeOddRotationsModule_basic",
     "AtenLinalgCrossBroadcast_basic",
     "AtenLinalgCrossCustomDim_basic",
     "AtenLinalgCrossFloat_basic",
@@ -1819,7 +1835,6 @@
     "ArangeStartOutModule_basic",
     "ArangeStartOutViewModule_basic",
     "ArangeStartStepIntModule_basic",
-    "ArangeZeroElementOutputModule_basic",
     "ArangeDtypeIntModule_basic",
     "ArangeFalsePinMemoryModule_basic",
     "ArangeFloatModule_basic",
@@ -2120,7 +2135,6 @@
     "NormScalarOptDimModule_basic",
     "NumToTensorFloatModule_basic",
     "NumToTensorIntModule_basic",
-    "NumpyTRank0Module_basic",
     "NumpyTRank1Module_basic",
     "NumpyTRank2Module_basic",
     "NumpyTRankNDynamicModule_basic",
@@ -2132,7 +2146,6 @@
     "OnesModuleInt_basic",
     "PadModule_basic",
     "PadWithNoneValModule_basic",
-    "Permute0RankModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
     "PrimListUnpackNumMismatchModule_basic",
@@ -2171,7 +2184,6 @@
     "ScalarTensorInt64Module_basic",
     "SelectIntNegativeDimAndIndexStaticModule_basic",
     "SiluModule_basic",
-    "SliceOutOfUpperBoundIndexStaticModule_basic",
     "SliceStaticModule_basic",
     "SplitTensorGetItem_Module_basic",
     "SplitTensorLastSmallerModule_basic",
@@ -3222,6 +3234,12 @@
 }
 
 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "ArangeZeroElementOutputModule_basic",
+    "NumpyTRank0Module_basic",
+    "Permute0RankModule_basic",
+    "SliceOutOfUpperBoundIndexModule_basic",
+    "SliceOutOfUpperBoundIndexStaticModule_basic",
+    "SliceStartEqEndModule_basic",
     "ChunkListUnpackDynamic_Module_basic",
     "ChunkListUnpackUnevenDynamic_Module_basic",
     "ChunkListUnpackUneven_Module_basic",
@@ -3240,11 +3258,6 @@
     "HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic", "HstackBasicIntModule_basic", - "Rot90BasicModule_basic", - "Rot90DynamicDimsModule_basic", - "Rot90MultipleRotationsModule_basic", - "Rot90NegativeEvenRotationsModule_basic", - "Rot90NegativeOddRotationsModule_basic", "AtenIntMM_basic", "AtenKthvalueDynamicDimsModule_basic", "AtenKthvalueFloat64DynamicDimsModule_basic", @@ -3263,7 +3276,6 @@ "ElementwiseRreluEvalStaticModule_basic", "ElementwiseRreluTrainModule_basic", "ElementwiseRreluTrainStaticModule_basic", - "FakeQuantizePerTensorAffineCachemaskModule_basic", "IndexPutWithNoneAndBroadcastModule_basic", "MaskedScatterStaticBasic_basic", "MaxUnpool3dModulePad0_basic", @@ -3342,8 +3354,6 @@ "AtenMmQuint8_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", - "AtenRoundFloatHalfToEvenModule_basic", - "AtenRoundFloatModule_basic", "AtenSubFloatModule_basic", "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", @@ -3504,20 +3514,6 @@ "EqIntModule_basic", "ExpandModule_basic", "ExponentialModule_basic", - "FakeQuantizePerTensorAffineDynamicShapeModule_basic", - "FakeQuantizePerTensorAffineModule_basic", - "FakeQuantizePerTensorAffineRoundToEvenModule_basic", - "Fill_TensorFloat32WithFloat32_basic", - "Fill_TensorFloat32WithFloat64_basic", - "Fill_TensorFloat32WithInt64_basic", - "Fill_TensorFloat64WithFloat32Static_basic", - "Fill_TensorFloat64WithFloat32_basic", - "Fill_TensorFloat64WithFloat64_basic", - "Fill_TensorFloat64WithInt64Static_basic", - "Fill_TensorFloat64WithInt64_basic", - "FlipModuleStaticShape_basic", - "FlipModule_basic", - "FlipNegativeIndexModule_basic", "FloatImplicitModule_basic", "FullLikeModuleInt2D_basic", "FullLikeModuleInt3D_basic", @@ -3847,9 +3843,7 @@ "VarMeanUnbiasedModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", - "ZeroFloat32Module_basic", - "ZeroInt32Module_basic", - "ZeroInt64Module_basic", + "VisionTransformerModule_basic", "ZerosLikeModule_falsePinMemory", } @@ -3862,6 +3856,12 @@ } ONNX_TOSA_XFAIL_SET = { + "ArangeZeroElementOutputModule_basic", + "LinspaceEmptyModule_basic", + "RepeatInterleaveSelfIntNoDimModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", + "TrilIndicesAllZerosModule_basic", + "TriuIndicesAllZerosModule_basic", "ElementwiseCreateComplexModule_basic", "ReduceAllDimFloatModule_basic", "AdaptiveMaxPool1dDimOneStatic_basic", @@ -4026,8 +4026,6 @@ "AtenPolarDoubleModule_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", - "AtenRoundFloatHalfToEvenModule_basic", - "AtenRoundFloatModule_basic", "AtenSubFloatModule_basic", "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", @@ -4071,8 +4069,6 @@ "BucketizeTensorFloatModule_basic", "BucketizeTensorModule_basic", "BucketizeTensorOutInt32RightModule_basic", - "BucketizeTensorStaticFloatModule_basic", - "BucketizeTensorStaticModule_basic", "CeilFloatModule_basic", "ChunkListUnpackDynamic_Module_basic", "ChunkListUnpackUnevenDynamic_Module_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 6690868af5100..e569fed7fa937 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1917,3 +1917,84 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !t %0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32> return %0 : !torch.vtensor<[4,5,2],f32> } + +// ----- + +// CHECK-LABEL: func.func 
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
+// CHECK: %[[VAL_1:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32>
+// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32>
+// CHECK: }
+func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.fill.Scalar %arg0, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
+  return %0 : !torch.vtensor<[1,12,128,128],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.fill.Tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32>
+// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 1>} : (tensor<1xi32>) -> tensor<1x1x1x1xi32>
+// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array<i64: 1, 12, 128, 128>} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32>
+// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32>
+// CHECK: }
+func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
+  %0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32>
+  return %0 : !torch.vtensor<[1,12,128,128],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.flip(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_1]] {axis = 1 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4,5],f32>
+// CHECK: }
+func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.flip %arg0, %0 : !torch.vtensor<[3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
+  return %1 : !torch.vtensor<[3,4,5],f32>
+}
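The test above shows a multi-axis flip lowered to one tosa.reverse per dim; chaining single-axis reverses is equivalent to one multi-axis flip, which is easy to confirm in NumPy (sketch, not from the patch):

```python
import numpy as np

x = np.arange(3 * 4 * 5, dtype=np.float32).reshape(3, 4, 5)

# One reverse per dim, axis 1 then axis 2, as in the test above.
chained = np.flip(np.flip(x, axis=1), axis=2)

assert np.array_equal(chained, np.flip(x, axis=(1, 2)))
```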
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.round(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_4:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_1]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_2]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_10:.*]] = tosa.equal %[[VAL_4]], %[[VAL_9]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
+// CHECK: %[[VAL_11:.*]] = tosa.equal %[[VAL_5]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xi1>
+// CHECK: %[[VAL_12:.*]] = tosa.greater %[[VAL_2]], %[[VAL_5]] : (tensor<f32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
+// CHECK: %[[VAL_13:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
+// CHECK: %[[VAL_14:.*]] = tosa.logical_or %[[VAL_12]], %[[VAL_13]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
+// CHECK: %[[VAL_15:.*]] = tosa.select %[[VAL_14]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
+// CHECK: return %[[VAL_16]] : !torch.vtensor<[3,4,5],f32>
+// CHECK: }
+func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
+  %0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  return %0 : !torch.vtensor<[3,4,5],f32>
+}
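As a cross-check of the graph in the test above, the same floor/sub/mul/equal/select sequence can be replayed in NumPy and compared against np.round, which also rounds half to even (a sketch assuming f32 input):

```python
import numpy as np

x = np.array([2.3, 2.5, 3.5, -0.5, -1.5], dtype=np.float32)

floor_x = np.floor(x)                             # tosa.floor
frac = x - floor_x                                # tosa.sub
ceil_x = np.ceil(x)                               # tosa.ceil
half_floor = np.floor(floor_x * 0.5)              # tosa.mul, tosa.floor
is_even = floor_x == half_floor * 2.0             # tosa.mul, tosa.equal
cond = (frac < 0.5) | ((frac == 0.5) & is_even)   # greater, equal, and, or
result = np.where(cond, floor_x, ceil_x)          # tosa.select

assert np.array_equal(result, np.round(x))
```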