From 4dd213b04223f2b49418205739702d80ff2c4a9b Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Wed, 30 Oct 2024 16:26:10 -0700 Subject: [PATCH] [TOSA] Expand Torch to TOSA legalization coverage (#3827) - Add/Extend Torch to TOSA legalization for the following ops: + Add aten.threshold_backward + Fix aten.threshold + Re-implement aten.broadcast_to using tosa.reshape and tosa.tile + Add support for rank 0 index for aten.index_select + Fix aten.index_put.hacked_twin + Add aten.uniform + Add aten.logical_and - Update xfail_sets.py with new e2e results - Add LIT tests to basic.mlir for newly added ops Change-Id: I8910564a049d18293284fe2e55e82bc1d2cf10e3 Signed-off-by: Justin Ngo --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 409 ++++++++++++++------- projects/pt1/e2e_testing/xfail_sets.py | 128 +++---- test/Conversion/TorchToTosa/basic.mlir | 86 ++++- 3 files changed, 399 insertions(+), 224 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b6dbdc2c7b8c..ce8351ea9920 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/TypeSwitch.h" #include #include +#include using namespace mlir; using namespace mlir::torch; @@ -125,15 +126,14 @@ template static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt, const int64_t &intValue) { if (isFloat) { - // Do a round-trip check here instead of numeric limits due to - // compiler warnings around double <-> int conversion. - return (doubleValue == static_cast(static_cast(doubleValue))); - } else { - assert(isInt); + return (doubleValue >= + static_cast(std::numeric_limits::min())) && + (doubleValue <= static_cast(std::numeric_limits::max())); + } else if (isInt) { return (intValue >= static_cast(std::numeric_limits::min())) && (intValue <= static_cast(std::numeric_limits::max())); } - return true; + return false; } // FIXME: This will eventually go into a Tosa*Utils file. @@ -165,13 +165,13 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, dshape, dtype) .value(); } else if (auto intType = dyn_cast(dtype)) { - auto w = intType.getWidth(); - if (w != 1 && w != 32 && w != 64) + auto width = intType.getWidth(); + if (width != 1 && width != 8 && width != 32 && width != 64) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << "Unsupported integer type: " << intType; }); - if (w == 1) { + if (width == 1) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( op, "Supplied value of scalar constant exceeds limits " @@ -182,7 +182,18 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, tosaTensor = tosa::getConstTensor( rewriter, op, SmallVector(numElem, d), dshape) .value(); - } else if (w == 32) { + } else if (width == 8) { + if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { + return rewriter.notifyMatchFailure( + op, "Supplied value of scalar constant exceeds limits " + "of destination type"); + } + int8_t d = isFloat ? 
static_cast(doubleValue) + : static_cast(intValue); + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); + } else if (width == 32) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( op, "Supplied value of scalar constant exceeds limits " @@ -193,7 +204,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, tosaTensor = tosa::getConstTensor( rewriter, op, SmallVector(numElem, d), dshape) .value(); - } else if (w == 64) { + } else if (width == 64) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( op, "Supplied value of scalar constant exceeds limits " @@ -919,13 +930,17 @@ class ConvertAtenMultipleDimsReductionOp ConversionPatternRewriter &rewriter, ElementsAttr &reduceDimsAttr, bool &keepDims) const override { - SmallVector reduceDims; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims))) - return rewriter.notifyMatchFailure(op, - "non-const dim parameter unsupported"); - int64_t N = reduceDims.size(); int64_t inputRank = cast(adaptor.getSelf().getType()).getRank(); + + SmallVector reduceDims; + // If dim list is none, all dimensions are reduced + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims))) { + for (int64_t i = 0; i < inputRank; i++) + reduceDims.push_back(i); + } + + int64_t N = reduceDims.size(); for (unsigned i = 0; i < N; i++) { reduceDims[i] = toPositiveDim(reduceDims[i], inputRank); if (!isValidDim(reduceDims[i], inputRank)) @@ -2895,9 +2910,10 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenThresholdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto selfType = dyn_cast(self.getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -2907,12 +2923,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only floating-point or integer datatype legalization supported"); - // Integer types with width > 32 are not supported - auto selfIntType = dyn_cast(selfElemTy); - if (selfIntType && selfIntType.getWidth() > 32) { - return rewriter.notifyMatchFailure( - op, "Integer types with width greater than 32 are not supported"); - } + auto outType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto outElemTy = outType.getElementType(); SmallVector constTypeShape(selfType.getRank(), 1); Value threshold, value; @@ -2922,21 +2935,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only scalar constant is supported for threshold"); if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(), value, - selfElemTy, constTypeShape))) + outElemTy, constTypeShape))) return rewriter.notifyMatchFailure( op, "Only scalar constant is supported for value"); - // Threshold only clamps the upper values. tosa::ClampOp has the same - // value for both threshold and clamped value so cannot be used. 
- auto outType = getTypeConverter()->convertType(op.getType()); - auto cmpOp = rewriter.create( op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), - adaptor.getSelf(), threshold); + self, threshold); - rewriter.replaceOpWithNewOp(op, outType, cmpOp, - adaptor.getSelf(), value); + rewriter.replaceOpWithNewOp(op, outType, cmpOp, self, value); return success(); } @@ -3660,8 +3668,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenBroadcastToOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto selfType = dyn_cast(self.getType()); if (!selfType || !selfType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); @@ -3675,19 +3684,43 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector resultShape; if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape))) return rewriter.notifyMatchFailure(op, - "size must consist of Scalar constants"); + "Size must consist of Scalar constants"); + + int64_t inputRank = selfType.getRank(); + int64_t outputRank = resultShape.size(); + if (inputRank > outputRank) + return rewriter.notifyMatchFailure( + op, "Input tensor rank cannot be greater than output tensor rank"); + // Get the result type auto resultType = getTypeConverter()->convertType(op.getType()); SmallVector inputShape( makeShapeTorchCompatible(selfType.getShape())); + + // If input rank is smaller than output rank, we reshape the input tensor to + // be the same rank as the output tensor by prepending 1s to the input shape + SmallVector targetInputShape; + for (int64_t i = 0; i < outputRank - inputRank; i++) + targetInputShape.push_back(1); + targetInputShape.append(inputShape); + // Result dimension -1 means not changing the size of that dimension. // Adjust it by assigning its inputShape. - for (auto shape : llvm::enumerate(makeShapeTorchCompatible(inputShape))) { + for (auto shape : + llvm::enumerate(makeShapeTorchCompatible(targetInputShape))) { auto index = shape.index(); if (resultShape[index] == -1) resultShape[index] = shape.value(); } + + for (int64_t i = 0; i < outputRank; i++) { + if (targetInputShape[i] != resultShape[i] && targetInputShape[i] != 1) + return rewriter.notifyMatchFailure( + op, "Input and result shapes should be equal at each dimension or " + "input shape should be 1"); + } + // Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is // true then we can replace the op result with the input operand directly. if (llvm::equal(inputShape, resultShape)) { @@ -3695,52 +3728,40 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // since the input and result are of same shape. op.replaceAllUsesWith(op.getSelf()); rewriter.eraseOp(op); - return success(); - } else if (selfType.hasRank() && - (selfType.getRank() == (int64_t)resultShape.size() || - selfType.getRank() == 0)) { - // Right now to support limited cases where input and result shape are not - // equal, we can put a constraint that either the input should be of rank - // 0 or the rank of input tensor and result should be equal. And then we - // can check for broadcasting compatibility for the latter case. For - // broadcasting compatibility, either the shape of input and result should - // be equal at each dimenion or one of them should be 1. 
- if (selfType.getRank() != 0) { - for (unsigned i = 0; i < inputShape.size(); i++) { - if (inputShape[i] != resultShape[i] && inputShape[i] != 1 && - resultShape[i] != 1) { - return rewriter.notifyMatchFailure( - op, "unimplemented: either the shape of input and result should " - "be equal at each dimenion or one of them should be 1."); - } + } else { + // By using reshape and tile ops, support for input rank smaller than result + // rank is allowed. If the rank is smaller, we reshape the input to be the + // same rank as the result, then use tile to expand it. The way it was + // handled before involves adding the input tensor to a const zero tensor of + // output shape to utilize the innate broadcast feature of the TOSA add op. + // That poses the danger of sign bit flips for denormalized values. + // Basically, this approach to broadcast_to legalization allows for more + // flexibility in rank differences and also offers more safety. + Value reshapedInput = self; + if (!llvm::equal(inputShape, targetInputShape)) + reshapedInput = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(targetInputShape), + selfElemTy), + self, rewriter.getDenseI64ArrayAttr(targetInputShape)); + + SmallVector tileOpShape; + for (int64_t i = 0; i < outputRank; i++) { + if (targetInputShape[i] == 1) { + tileOpShape.push_back(resultShape[i]); + } else { + tileOpShape.push_back(1); } } - // If the above condition hold true then we can directly create a const - // zero tensor of shape same as the result shape. - SmallVector zeroTensorShape{resultShape}; + auto result = rewriter.create( + op->getLoc(), resultType, reshapedInput, + rewriter.getDenseI64ArrayAttr(tileOpShape)); - // create the 0 constant tensor - int64_t totalNumElements = 1; - for (auto dimSize : zeroTensorShape) { - totalNumElements = dimSize * totalNumElements; - } - // There is some danger here. For edge cases in floating point, x + 0 != x. - // The cases are denormalized values, which may get flushed, and -0 + 0 = - // +0. (sign bit flips). These are probably acceptable in the short term, - // but we should put a comment acknowledging the danger, as there isn't an - // op that avoids the denorm flushing. 
- Value zeroTensor = - tosa::getZerosLikeTensor(rewriter, op, resultType).value(); - - // Use add broadcast - rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf(), - zeroTensor); - return success(); + rewriter.replaceOp(op, {result.getResult()}); } - return rewriter.notifyMatchFailure( - op, - "unimplemented: broadcasts other than same rank or zero ranked tensor."); + + return success(); } template <> @@ -3843,6 +3864,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto index = adaptor.getIndex(); auto indexType = dyn_cast(index.getType()); + auto indexShape = indexType.getShape(); if (!indexType) return rewriter.notifyMatchFailure( @@ -3851,9 +3873,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto inputShape = inputType.getShape(); int inputRank = inputType.getRank(); - if (indexType.getRank() == 0) - return rewriter.notifyMatchFailure( - op, "Rank 0 index tensor is currently not supported"); + if (indexType.getRank() == 0) { + indexShape = makeShapeTorchCompatible({1}); + index = rewriter.create( + op->getLoc(), + RankedTensorType::get(indexShape, indexType.getElementType()), index, + rewriter.getDenseI64ArrayAttr(indexShape)); + } // Dynamic shape check if (!inputType.hasStaticShape() || !indexType.hasStaticShape()) @@ -3865,9 +3891,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (indexType.getElementType() != rewriter.getIntegerType(32)) { index = rewriter.create( op->getLoc(), - RankedTensorType::get(indexType.getShape(), - rewriter.getIntegerType(32)), - index); + RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); } // Get positive dim @@ -3896,7 +3920,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector indicesInputRankShape; for (int64_t i = 0; i < inputRank; i++) { if (i == dim) { - indicesInputRankShape.push_back(indexType.getShape()[0]); + indicesInputRankShape.push_back(indexShape[0]); } else { indicesInputRankShape.push_back(1); } @@ -3952,49 +3976,41 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexPutHackedTwinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - // a = torch.tensor([[0, 1, 2, 3]]) - // a[..., 1:] = torch.tensor([4, 5, 6]) - // = a[..., 1:4] = torch.tensor([4, 5, 6]) - // = a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6]) # tensor([[0, 4, 5, - // 6]]) = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input - // (torch.tensor([0, 0, 0]), torch.tensor([1, 2, - // 3])), # indicies torch.tensor([4, 5, 6])) # - // value - // = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input - // (None, torch.tensor([1, 2, 3]),),# indicies - // torch.tensor([4, 5, 6])) # value - // Not a tensor type. auto input = adaptor.getSelf(); - auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto selfType = dyn_cast(input.getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); auto fillValues = adaptor.getValues(); - auto valuesType = dyn_cast(adaptor.getValues().getType()); + auto valuesType = dyn_cast(fillValues.getType()); if (!valuesType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); // Deal with torch.prim.ListConstruct of non const value to get the index + // Index_put-like ops are now decomposed to aten.index_put.hacked_twin with + // stricter semantics, i.e., no None index in indices argument. 
auto tensorList = op.getIndices(); SmallVector tensorsTorchType; if (!getListConstructElements(tensorList, tensorsTorchType)) - return op.emitError( - "unimplemented: the tensor list is not from list construct"); + return op.emitError("Tensor list is not from list construct"); auto indexTensors = getTypeConvertedValues( rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType); auto outType = getTypeConverter()->convertType(op.getType()); - // convert list of indices with none into indices tensor without none - // indexTensors (none,[1,2,3]) -> ([0,0,0],[1,2,3]) - // ([[0],[0],[0]],[[1],[2],[3]])-> [[0,1],[0,2], [0,3]] - if (indexTensors.size() <= 1) { + bool accumulate{false}; + if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) return rewriter.notifyMatchFailure( - op, "Only support indexput with multiple index."); - } + op, "Accumulate is not a constant bool value"); + + // No support for accumulate mode yet + if (accumulate) + return rewriter.notifyMatchFailure( + op, "Accumulate mode is not currently supported"); + SmallVector indicesTfConcatTensors; SmallVector indexesRank; SmallVector> indexesShape; @@ -4002,28 +4018,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // concat index tensor into to indices tensor for concat for (size_t i = 0; i < indexTensors.size(); i++) { auto index = indexTensors[i]; - auto indexTorch = tensorsTorchType[i]; - // TODO add support for none index other than i==0, like (index0, None) - // (None, index1) - if (i == 0 && isa(indexTorch.getType())) { - // convert None to [0,0,0] - auto indexNext = indexTensors[i + 1]; - auto indexNextTorch = tensorsTorchType[i + 1]; - if (isa(indexNextTorch.getType())) { - return rewriter.notifyMatchFailure( - op, "Multiple None index is not support for now."); - } - auto indexNextType = dyn_cast(indexNext.getType()); - auto indexNextShape = indexNextType.getShape(); - - int64_t size = 1; - for (auto s : indexNextShape) - size *= s; - SmallVector values(size, i); - index = - tosa::getConstTensor(rewriter, op, values, indexNextShape) - .value(); - } auto indexType = dyn_cast(index.getType()); auto indexShape = indexType.getShape(); @@ -4031,20 +4025,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indexesRank.push_back(indexType.getRank()); // index i64 to i32 for tosa compatible - if (indexType.getElementType() != rewriter.getIntegerType(32)) { + if (indexType.getElementType() != rewriter.getIntegerType(32)) index = rewriter.create( op->getLoc(), RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); - } // Expand last dim of index to tf indices [3] -> [3,1] // convert [0,0,0] to [[0],[0],[0]] SmallVector indiceShapeOneDim; - for (auto shape : indexShape) { + for (auto shape : indexShape) indiceShapeOneDim.push_back(shape); - } indiceShapeOneDim.push_back(1); + auto indicesTfOneDim = tosa::CreateOpAndInfer( rewriter, op->getLoc(), RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)), @@ -4061,7 +4054,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( for (auto indexShapeOneDim : indexesShape) { if (!llvm::equal(indexesShape[0], indexShapeOneDim)) { return rewriter.notifyMatchFailure( - op, "unimplemented: Only support multi indexes with same shape"); + op, "Only support indices with same shape"); } } @@ -4075,19 +4068,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)), indicesTfConcatTensors, lastDim); - if (!indicesTf) { - return rewriter.notifyMatchFailure(op, - "Convert TorchIndex To 
TfIndices fail.");
-  }
 
-  // do the tf scatterNd algorithm with tf style indices as input, algorithm
-  // mostly take from convertGatherNdOp.
+  if (!indicesTf)
+    return rewriter.notifyMatchFailure(
+        op, "Convert PyTorch index to TensorFlow indices failed");
+
   auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
                                          indicesTf.getResult(), fillValues);
 
-  if (!result) {
-    return rewriter.notifyMatchFailure(
-        op, "Convert ScatterNdOp fail for index tensor.");
-  }
+  if (!result)
+    return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
+
   rewriter.replaceOp(op, {result.value()});
 
   return success();
@@ -6632,6 +6622,140 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.uniform
+// Since TOSA hasn't got a built-in random generator yet, we will use
+// std::uniform_real_distribution with std::default_random_engine from the
+// C++ standard library
+template <>
+LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
+    AtenUniformOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
+
+  // Not a tensor type
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+
+  auto selfShape = selfType.getShape();
+
+  auto generator = adaptor.getGenerator();
+  if (!isa<Torch::NoneType>(generator.getType()))
+    return rewriter.notifyMatchFailure(op,
+                                       "Custom generators are not supported");
+
+  double fromDouble{0.0}, toDouble{1.0};
+  auto isFloat =
+      matchPattern(op.getFrom(), m_TorchConstantFloat(&fromDouble)) &&
+      matchPattern(op.getTo(), m_TorchConstantFloat(&toDouble));
+
+  int64_t fromInt{0}, toInt{1};
+  auto isInt = matchPattern(op.getFrom(), m_TorchConstantInt(&fromInt)) &&
+               matchPattern(op.getTo(), m_TorchConstantInt(&toInt));
+
+  if (!isFloat && !isInt)
+    return rewriter.notifyMatchFailure(
+        op, "From and To values are not constant values");
+
+  int64_t numElem = 1;
+  for (int64_t i = 0; i < selfType.getRank(); i++)
+    numElem *= selfShape[i];
+
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+
+  std::default_random_engine gen;
+
+  auto from = isFloat ? fromDouble : fromInt;
+  auto to = isFloat ? toDouble : toInt;
+
+  std::uniform_real_distribution<float> uniformDist(from, to);
+  SmallVector<float> uniformVec;
+
+  for (int64_t i = 0; i < numElem; i++)
+    uniformVec.push_back(uniformDist(gen));
+
+  auto result = tosa::getConstTensor<float>(rewriter, op, uniformVec,
+                                            selfShape,
+                                            selfType.getElementType())
+                    .value();
+
+  result = tosa::promoteType(rewriter, result, resultType);
+
+  rewriter.replaceOp(op, {result});
+
+  return success();
+}
+
+// Legalization for aten.threshold_backward
+// result = self <= threshold ? 0 : grad
+template <>
+LogicalResult ConvertAtenOp<AtenThresholdBackwardOp>::matchAndRewrite(
+    AtenThresholdBackwardOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
+
+  // Not a tensor type
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+  auto selfElemTy = selfType.getElementType();
+
+  auto selfShape = selfType.getShape();
+
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+  auto resultElemTy = resultType.getElementType();
+
+  Value threshold;
+  if (failed(torchScalarToTosaTensor(rewriter, op, op.getThreshold(), threshold,
+                                     selfElemTy, selfShape)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Threshold must be a constant scalar");
+
+  auto grad = adaptor.getGradOutput();
+
+  // Not a tensor type
+  auto gradType = dyn_cast<TensorType>(grad.getType());
+  if (!gradType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+
+  Value zero =
+      TypeSwitch<Type, Value>(resultElemTy)
+          .Case<mlir::FloatType>([&](auto) {
+            return tosa::getConstTensor<float>(rewriter, op, 0, {},
                                                resultElemTy)
+                .value();
+          })
+          .Case<mlir::IntegerType>([&](auto intType) {
+            switch (intType.getWidth()) {
+            case 1:
+              return tosa::getConstTensor<bool>(rewriter, op, 0, {}).value();
+            case 8:
+              return tosa::getConstTensor<int8_t>(rewriter, op, 0, {}).value();
+            case 32:
+              return tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
+            case 64:
+              return tosa::getConstTensor<int64_t>(rewriter, op, 0, {}).value();
+            }
+            llvm_unreachable("Invalid integer width");
+          });
+
+  // Check: input <= threshold
+  auto cond = rewriter.create<tosa::GreaterEqualOp>(
+      op->getLoc(), RankedTensorType::get(selfShape, rewriter.getI1Type()),
+      threshold, self);
+
+  self = tosa::promoteType(rewriter, self, resultType);
+  grad = tosa::promoteType(rewriter, grad, resultType);
+
+  auto result = rewriter.create<tosa::SelectOp>(op->getLoc(), resultType,
                                                cond.getResult(), zero, grad);
+
+  rewriter.replaceOp(op, {result.getResult()});
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -6705,6 +6829,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
     INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
     INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
+    INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
     INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
                           tosa::LogicalLeftShiftOp)
     INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
@@ -6947,6 +7072,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
     INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
     INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
+    INSERT_ATENOP_PATTERN(AtenUniformOp);
+    INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 3881aa145d1c..854c2d8710c6 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1707,9 +1707,17 @@
     "ScatterSrcStaticModule_basic",
     # Runtime op verification: Out of bounds access
     "ReduceAllDimEmpty_basic",
+    # SmallVector unable to grow for ThresholdBackward1d
+    "ThresholdBackward1dFloatModule_basic",
+    "ThresholdBackward1dIntModule_basic",
+    "ThresholdBackward1dMixedModule_basic",
 }
 
 FX_IMPORTER_TOSA_CRASHING_SET = {
+    "GridSamplerBasic1_basic",
+    "GridSamplerBasic2_basic",
+    "GridSamplerBasic3_basic",
+
"GridSamplerBasic4_basic", "ScatterSrcModule_basic", "ScatterSrcStaticModule_basic", "HBC_basic", @@ -1727,6 +1735,25 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "CosineSimilarityStaticBroadcastModule_basic", + "DropoutTrainStaticShapeModule_basic", + "ElementwiseAtenLogicalAndOpModule_basic", + "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", + "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseRreluTrainStaticModule_basic", + "IndexSelectRank0IdxModule_basic", + "MseLossSumReductionWithDifferentElemTypeModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "RandIntDtypeModule_basic", + "RandIntLowDtypeModule_basic", + "RandModule_basic", + "ReduceL3NormAllDimsModule_basic", + "ReduceL3NormKeepDimModule_basic", + "SliceCopy_Module_basic", + "Threshold1dIntModule_basic", + "Threshold2dIntModule_basic", + "Threshold3dIntModule_basic", "EmptyModule_contiguous", "EmptyModule_defaultDtype", "EmptyModule_falsePinMemory", @@ -2296,8 +2323,6 @@ "TensorIntModule_basic", "TensorLiteralModule_basic", "TensorOpaqueLiteralModule_basic", - "TensorsConcatNegativeDimStaticModule_basic", - "TensorsConcatStaticModule_basic", "TestF16Return_basic", "TestMultipleTensorReturn_basic", "Threshold1dFloatModule_basic", @@ -2363,7 +2388,6 @@ "LinspaceModule_basic", "LinspaceOneSizeModule_basic", "LinspaceTwoSizeModule_basic", - "TorchPrimLoopForLikeTensorArgModule_basic", "RenormModuleFloat32NegativeDim_basic", "RenormModuleFloat32_basic", "IndexTensorStaticContiguousWithNoneModule_basic", @@ -2468,7 +2492,6 @@ "SplitWithSizesListUnpackModule_basic", # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", - "MatmulStaticBroadcast_basic", # Unimplemented operator 'aten._index_put_impl_.hacked_twin' "IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DIntNonAccumulateModule_basic", @@ -2487,7 +2510,6 @@ "ElementwiseLogSigmoidModule_basic", # failed to legalize operation 'torch.aten.rrelu_with_noise' "ElementwiseRreluEvalModule_basic", - "ElementwiseRreluEvalStaticModule_basic", # incompatible return type failure for tosa.concat. 
"HstackBasicComplexModule_basic", "HstackBasicFloatModule_basic", @@ -3329,6 +3351,14 @@ } FX_IMPORTER_TOSA_XFAIL_SET = { + "AdaptiveMaxPool1dDimOneStatic_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MaxPool3dStaticModule_basic", "ViewDtypeStaticModule_basic", "Unfold_Module_Dynamic_basic", "Unfold_Module_Rank_4", @@ -3474,7 +3504,6 @@ "BoolIntFalseModule_basic", "BoolIntTrueModule_basic", "BroadcastDynamicDimModule_basic", - "BroadcastToModule_basic", "CeilFloatModule_basic", "CollapseAllDimensionsModule_basic", "CollapseFullDynamicModule_basic", @@ -3509,7 +3538,6 @@ "ConvolutionModule2DTransposeStrided_basic", "ConvolutionModule2DTranspose_basic", "CopyWithDifferentDTypesModule_basic", - "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", "CumsumModule_basic", "CumsumStaticModule_basic", @@ -3524,8 +3552,6 @@ "DeterminantModule_F32", "DivFloatModule_basic", "DivIntModule_basic", - "DropoutTrainModule_basic", - "DropoutTrainStaticShapeModule_basic", "ElementwiseAcosIntModule_basic", "ElementwiseAcosModule_basic", "ElementwiseAcoshIntModule_basic", @@ -3545,11 +3571,7 @@ "ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", - "ElementwiseAtenLogicalAndOpModule_basic", - "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", - "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", - "ElementwiseBitwiseAndScalarInt8Module_basic", "ElementwiseClampMinTensorFloatModule_basic", "ElementwiseClampMinTensorIntModule_basic", "ElementwiseClampTensorFloatModule_basic", @@ -3590,12 +3612,9 @@ "ElementwiseUnaryIntModule_basic", "ElementwiseWhereScalarOtherStaticModule_basic", "EqIntModule_basic", - "ExpandModule_basic", - "ExponentialModule_basic", "FloatImplicitModule_basic", "FullLikeModuleInt2D_basic", "FullLikeModuleInt3D_basic", - "FullModuleInt2D_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", "GeIntModule_basic", @@ -3606,42 +3625,25 @@ "GtFloatIntModule_basic", "GtIntModule_basic", "IndexPut1DFloatAccumulateModule_basic", - "IndexPut1DFloatNonAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic", - "IndexPut1DIntNonAccumulateModule_basic", "IndexPut2DFloatAccumulateModule_basic", - "IndexPut2DFloatNonAccumulateModule_basic", "IndexPut2DIntAccumulateModule_basic", - "IndexPut2DIntNonAccumulateModule_basic", "IndexPut3DFloatAccumulateModule_basic", - "IndexPut3DFloatNonAccumulateModule_basic", "IndexPut3DIntAccumulateModule_basic", - "IndexPut3DIntNonAccumulateModule_basic", "IndexPutHackedTwin1DFloatAccumulateModule_basic", - "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", "IndexPutHackedTwin1DIntAccumulateModule_basic", - "IndexPutHackedTwin1DIntNonAccumulateModule_basic", "IndexPutHackedTwin2DFloatAccumulateModule_basic", - "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", "IndexPutHackedTwin2DIntAccumulateModule_basic", - "IndexPutHackedTwin2DIntNonAccumulateModule_basic", "IndexPutHackedTwin3DFloatAccumulateModule_basic", - "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", "IndexPutHackedTwin3DIntAccumulateModule_basic", - "IndexPutHackedTwin3DIntNonAccumulateModule_basic", "IndexPutImpl1DFloatAccumulateModule_basic", - "IndexPutImpl1DFloatNonAccumulateModule_basic", 
"IndexPutImpl1DIntAccumulateModule_basic", - "IndexPutImpl1DIntNonAccumulateModule_basic", "IndexPutImpl2DFloatAccumulateModule_basic", - "IndexPutImpl2DFloatNonAccumulateModule_basic", "IndexPutImpl2DImplicitModule_basic", "IndexPutImpl2DIndexModule_basic", "IndexPutImpl2DNoneIndexStaticModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic", - "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", - "IndexSelectRank0IdxModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", "InterpolateStaticModule_scales_bilinear_align_corners", @@ -3656,8 +3658,7 @@ "LinspaceDtypeModule_basic", "LinspaceEmptyModule_basic", "MaskedFillTensorFloatValueModule_basic", - "MatmulBroadcastBatchDim_basic", - "MatmulStaticBroadcast_basic", + "MaskedScatterStaticBasic_basic", "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", "MaxPool2dCeilModeTrueModule_basic", @@ -3689,17 +3690,16 @@ "MaxPool3dWithIndicesNonDefaultStrideModule_basic", "MaxPool3dWithIndicesStaticModule_basic", "MeanDimEmptyDimModule_basic", - "MeanDimNoneDimModule_basic", - "MseLossMeanReductionModule_basic", - "MseLossSumReductionWithDifferentElemTypeModule_basic", + "MlGroupNormManualModule_basic", + "MlGroupNormModule_basic", + "MlLayerNormManualModule_basic", + "MlLayerNormModule_basic", "MulFloatModule_basic", "MulIntModule_basic", "NativeBatchNorm1DModule_basic", "NativeBatchNorm2DModule_basic", "NativeBatchNorm3DModule_basic", "NativeBatchNormNoneWeightModule_basic", - "NativeDropoutTrainModule_basic", - "NativeDropoutTrainStaticShapeModule_basic", "NativeGroupNormBackwardModule_basic", "NeFloatIntModule_basic", "NeIntModule_basic", @@ -3741,14 +3741,9 @@ "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", "QuantizedSingleLayer_basic", - "RandIntDtypeModule_basic", - "RandIntLowDtypeModule_basic", "RandIntLowModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", - "RandLikeDtypeModule_basic", - "RandLikeModule_basic", - "RandModule_basic", "RandnDtypeDeviceModule_basic", "RandnGeneratorF64Module_basic", "RandnGeneratorModule_basic", @@ -3760,9 +3755,7 @@ "ReduceL1NormComplexModule_basic", "ReduceL1NormWithDTypeModule_basic", "ReduceL2NormComplexModule_basic", - "ReduceL3NormAllDimsModule_basic", "ReduceL3NormKeepDimComplexModule_basic", - "ReduceL3NormKeepDimModule_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "ReduceSumDimIntListEmptyDimModule_basic", @@ -3843,18 +3836,7 @@ "TensorsConcatPromoteDTypeModule_basic", "TensorsStackPromoteDTypeModule_basic", "TestMultipleTensorAndPrimitiveTypesReturn_basic", - "Threshold1dIntModule_basic", - "Threshold2dIntModule_basic", - "Threshold3dIntModule_basic", - "ThresholdBackward1dFloatModule_basic", - "ThresholdBackward1dIntModule_basic", - "ThresholdBackward1dMixedModule_basic", - "ThresholdBackward2dFloatModule_basic", - "ThresholdBackward2dIntModule_basic", "ThresholdBackward2dMixedModule_basic", - "ThresholdBackward3dFloatModule_basic", - "ThresholdBackward3dIntModule_basic", - "ThresholdBackward3dMixedModule_basic", "ToCopyWithDTypeFalsePinMemoryModule_basic", "ToCopyWithDTypeModule_basic", "TorchPrimLoopForLikeModule_basic", @@ -3863,10 +3845,6 @@ "TraceUnsignedIntModule_empty", "TypeConversionI1ToF64Module_basic", "TypeConversionI1ToI32Module_basic", - "UniformModule_basic", - "UniformNoCorrelationModule_basic", - "UniformStaticShapeModule_basic", - "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", 
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UpSampleNearest2dBackwardScalesNone_basic", "UpSampleNearest2dBackward_basic", @@ -3875,9 +3853,6 @@ "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticSize_basic", "UpSampleNearest2d_basic", - "VarMeanBiasedModule_basic", - "VarMeanCorrectionNoneModule_basic", - "VarMeanUnbiasedModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", "VisionTransformerModule_basic", @@ -3894,6 +3869,15 @@ } ONNX_TOSA_XFAIL_SET = { + "ElementwiseRreluWithNoiseEvalModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "RreluWithNoiseBackwardEvalModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", + "RreluWithNoiseForwardBackwardModule_basic", "Unfold_Module_Dynamic_basic", "Unfold_Module_Rank_4", "Unfold_Module_Rank_Zero_Size_Zero_basic", @@ -3937,12 +3921,10 @@ "Conv_Transpose2dStaticModule_basic", "Conv_Transpose3dModule_basic", "Conv_Transpose3dStaticModule_basic", - "EinsumStaticModule_basic", "ElementwiseFmaxModule_basic", "ElementwiseFminModule_basic", "ElementwiseGeluApproximateTanhModule_basic", "ElementwiseIntTensorLtFloatTensorModule_basic", - "ElementwiseNanToNumWithNoneModule_Basic", "ElementwiseRad2DegIntModule_basic", "ElementwiseRad2DegModule_basic", "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", @@ -4106,7 +4088,6 @@ "BoolIntConstantModule_basic", "BoolIntFalseModule_basic", "BoolIntTrueModule_basic", - "BoolTensorHandleSignless_basic", "BroadcastDynamicDimModule_basic", "BroadcastToModule_basic", "BucketizeTensorFloatModule_basic", @@ -4123,10 +4104,6 @@ "CollapseRank1DynamicModule_basic", "CollapseStaticModule_basic", "ConstantBoolParameterModule_basic", - "ConstantPad2dStaticModule_basic", - "ConstantPadNdModule_basic", - "ConstantPadNdPartialStaticModule_basic", - "ConstantPadNdStaticModule_basic", "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", @@ -4220,9 +4197,7 @@ "ElementwiseAtenFloorDivideTensorPositiveModule_basic", "ElementwiseAtenIsneginfOpModule_basic", "ElementwiseAtenIsposinfOpModule_basic", - "ElementwiseAtenLogicalAndOpModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", - "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseAtenLogicalOrOpBrodcastModule_basic", "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic", @@ -4254,7 +4229,6 @@ "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", - "ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic", "ElementwiseDivScalarRoundingModeTruncModule_basic", "ElementwiseDivScalarRoundingModeTruncStaticModule_basic", "ElementwiseDivTensorFloatModule_basic", @@ -4291,7 +4265,6 @@ "ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorFloatModule_basic", "ElementwiseMulTensorIntModule_basic", - "ElementwiseNanToNumModule_Basic", "ElementwiseOrTensorModule_basic", "ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseQuantizePerTensorModule_basic", @@ -4579,8 +4552,6 @@ "OnesLikeModule_falsePinMemory", "OnesLikeModule_float", "OnesLikeModule_int", - "PadModule_basic", - "PadWithNoneValModule_basic", "PermuteNegativeIndexModule_basic", "PixelShuffleModuleFullDynamic_basic", 
"PixelShuffleModuleSpatiallyDynamic_basic", @@ -4688,7 +4659,6 @@ "ReflectionPad2dModule_Right", "ReflectionPad2dModule_Top", "ReflectionPad2dModule_basic", - "RepeatModule_basic", "ReplicationPad2dModule_basic", "ReplicationPad2dModule_bottom0", "ReplicationPad2dModule_left0", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index ed6f909c4a1b..80dcc0ac7937 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2159,7 +2159,85 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t // CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64> func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { - %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list - %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list -> !torch.vtensor<[4,2],si64> - return %1 : !torch.vtensor<[4,2],si64> - } + %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list + %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + return %1 : !torch.vtensor<[4,2],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.threshold_backward$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4],si64>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],si64> -> tensor<4xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4],si64> -> tensor<4xi64> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_2]] : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1> +// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<4xi1>, tensor, tensor<4xi64>) -> tensor<4xi64> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<4xi64> -> !torch.vtensor<[4],si64> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[4],si64> +// CHECK: } +func.func @torch.aten.threshold_backward$basic(%arg0: !torch.vtensor<[4],si64>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { + %int1 = torch.constant.int 1 + %0 = torch.aten.threshold_backward %arg0, %arg1, %int1 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.threshold$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],si64> -> tensor<4x5xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 5.000000e-01 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_6:.*]] = tosa.greater %[[VAL_1]], %[[VAL_4]] : (tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi1> +// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_6]], 
%[[VAL_1]], %[[VAL_5]] : (tensor<4x5xi1>, tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi64> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4x5xi64> -> !torch.vtensor<[4,5],si64> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,5],si64> +// CHECK: } +func.func @torch.aten.threshold$basic(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> { + %float5.000000e-01 = torch.constant.float 5.000000e-01 + %int2 = torch.constant.int 2 + %0 = torch.aten.threshold %arg0, %float5.000000e-01, %int2 : !torch.vtensor<[4,5],si64>, !torch.float, !torch.int -> !torch.vtensor<[4,5],si64> + return %0 : !torch.vtensor<[4,5],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.logical_and$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1> +// CHECK: %[[VAL_4:.*]] = tosa.logical_and %[[VAL_3]], %[[VAL_2]] : (tensor<4x5xi1>, tensor<4x5xi1>) -> tensor<4x5xi1> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,5],i1> +// CHECK: } +func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[4,5],i1>, %arg1: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> { + %0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1> + return %0 : !torch.vtensor<[4,5],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.uniform$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) { +// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e+01 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.00007045, 2.18384027, 7.80044794, 5.12785149], [5.79490519, 2.97063255, 1.42340159, 7.10978221], [7.11366796, 9.41223621, 4.45151854, 5.67474747]]> : tensor<3x4xf32>}> : () -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf64> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf64> -> !torch.vtensor<[3,4],f64> +// CHECK: return %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64> +// CHECK: } +func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %float1.000000e01 = torch.constant.float 1.000000e+01 + %none = torch.constant.none + %0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64> + return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64> +}