diff --git a/externals/stablehlo b/externals/stablehlo
index 271e8634de18..ab92adeda911 160000
--- a/externals/stablehlo
+++ b/externals/stablehlo
@@ -1 +1 @@
-Subproject commit 271e8634de184fbfafd677d3876170feb6d08c97
+Subproject commit ab92adeda9119a6c3914cd42367b0a2b70765e91
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index bd5c57fac3ba..f7fac538068a 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -35,6 +35,108 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
   return success();
 }
 
+namespace {
+LogicalResult windowFunctionImpl(OpBinder binder,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value size, Value a0, Value a1, Value a2,
+                                 Torch::ValueTensorType resultType,
+                                 int64_t output_datatype, int64_t periodic) {
+
+  Location loc = binder.getLoc();
+  ImplicitLocOpBuilder b(loc, rewriter);
+
+  double isPeriodicFp = static_cast<double>(periodic);
+
+  Value zero = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(0.0));
+  Value one = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1.0));
+  Value two = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2.0));
+
+  constexpr double pi = llvm::numbers::pi;
+  Value tau = b.create<Torch::ConstantFloatOp>(
+      rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
+
+  Value noneVal = b.create<Torch::ConstantNoneOp>();
+  Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
+  Value float32Type = b.create<Torch::ConstantIntOp>(
+      rewriter.getI64IntegerAttr(/*float32Type*/ 6));
+
+  // Create an f32 ValueTensorType with the same size as size, the operand.
+  auto shapeOfOperand =
+      size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
+  auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
+      shapeOfOperand, rewriter.getF32Type());
+  Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
+      f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal);
+  Value symmetricSizeFloat = b.create<Torch::AtenSubScalarOp>(
+      periodicSizeFloat.getType(), periodicSizeFloat, one, one);
+
+  Value isPeriodic =
+      b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(isPeriodicFp));
+  Value isSymmetricFloat = b.create<Torch::ConstantFloatOp>(
+      rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
+
+  Value periodicComponent = b.create<Torch::AtenMulScalarOp>(
+      periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic);
+  Value symmetricComponent = b.create<Torch::AtenMulScalarOp>(
+      symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat);
+  Value sizeFloat = b.create<Torch::AtenAddTensorOp>(
+      symmetricComponent.getType(), symmetricComponent, periodicComponent, one);
+
+  // Here, size can be used in the place of periodicSizeFloat, as the
+  // latter is just a float representation of the former.
+  Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
+
+  Value rangeArr = b.create<Torch::AtenArangeStartStepOp>(
+      resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, noneVal);
+
+  Value rangeTimesTau =
+      b.create<Torch::AtenMulScalarOp>(resultType, rangeArr, tau);
+  Value rangeAngular =
+      b.create<Torch::AtenDivTensorOp>(resultType, rangeTimesTau, sizeFloat);
+  Value twoRangeAngular =
+      b.create<Torch::AtenMulScalarOp>(resultType, rangeAngular, two);
+
+  Value cosRangeAngular = b.create<Torch::AtenCosOp>(resultType, rangeAngular);
+  Value cosTwoRangeAngular =
+      b.create<Torch::AtenCosOp>(resultType, twoRangeAngular);
+
+  Value a1Component =
+      b.create<Torch::AtenMulScalarOp>(resultType, cosRangeAngular, a1);
+  Value a2Component =
+      b.create<Torch::AtenMulScalarOp>(resultType, cosTwoRangeAngular, a2);
+
+  // AtenSubScalarOp actually requires a tensor operand as the LHS, that
+  // is, operand #1. Therefore, to avoid errors, the onnx implementation
+  // has been modified. a1 has been changed to negative half, and the
+  // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
+  // operation is commutative.
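+  // Combine the terms of the generalized cosine window
+  //   w[n] = a0 + a1 * cos(2*pi*n / N) + a2 * cos(4*pi*n / N),
+  // where the subtracted middle term of the textbook definition is folded
+  // into a negative a1 by the callers below.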
+  Value subA1Component =
+      b.create<Torch::AtenAddScalarOp>(resultType, a1Component, a0, one);
+  Value result = b.create<Torch::AtenAddTensorOp>(resultType, subA1Component,
+                                                  a2Component, one);
+
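+  // Cast the computed f32 window to the dtype requested through the ONNX
+  // output_datatype attribute.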
+  std::optional<int64_t> dtypeIntTorch =
+      onnxDtypeIntToTorchDtypeInt(output_datatype);
+  if (!dtypeIntTorch.has_value()) {
+    return rewriter.notifyMatchFailure(
+        binder.op, "unimplemented support for the given dtype conversion");
+  }
+  Value outputDtype = b.create<Torch::ConstantIntOp>(
+      rewriter.getType<Torch::IntType>(),
+      rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                              dtypeIntTorch.value()));
+
+  rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+      binder.op, resultType, result, outputDtype,
+      /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+      /*memory_format=*/noneVal);
+
+  return success();
+}
+
+} // namespace
+
 // Simple rewrites for the default domain.
 // See: https://onnx.ai/onnx/operators/
 // For operators that are effectively version invariant, we register with
@@ -198,29 +300,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                       binder.op, resultType, operand);
                   return success();
                 });
-  patterns.onOp(
-      "Asinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        Value operand;
-        if (binder.tensorOperand(operand) ||
-            binder.tensorResultType(resultType))
-          return failure();
-
-        // log(x + sqrt(x**2 + 1))
-        Value square = rewriter.create<Torch::AtenSquareOp>(
-            binder.getLoc(), resultType, operand);
-        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(1));
-        Value add0 = rewriter.create<Torch::AtenAddScalarOp>(
-            binder.getLoc(), resultType, square, cstOne, cstOne);
-        Value sqrt = rewriter.create<Torch::AtenSqrtOp>(binder.getLoc(),
-                                                        resultType, add0);
-        Value add1 = rewriter.create<Torch::AtenAddTensorOp>(
-            binder.getLoc(), resultType, operand, sqrt, cstOne);
-        rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(binder.op, resultType,
-                                                      add1);
-        return success();
-      });
+  patterns.onOp("Asinh", 9,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value operand;
+                  if (binder.tensorOperand(operand) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+                  rewriter.replaceOpWithNewOp<Torch::AtenAsinhOp>(
+                      binder.op, resultType, operand);
+                  return success();
+                });
   patterns.onOp("Atan", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
                   Value operand;
@@ -232,33 +322,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                       binder.op, resultType, operand);
                   return success();
                 });
-  patterns.onOp(
-      "Atanh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        Value operand;
-        if (binder.tensorOperand(operand) ||
-            binder.tensorResultType(resultType))
-          return failure();
-
-        // 1/2 * log((1 + x) / (1 - x))
-        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(1));
-        Value add = rewriter.create<Torch::AtenAddScalarOp>(
-            binder.getLoc(), resultType, operand, cstOne, cstOne);
-        Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
-                                                      resultType, operand);
-        Value sub = rewriter.create<Torch::AtenAddScalarOp>(
-            binder.getLoc(), resultType, neg, cstOne, cstOne);
-        Value div = rewriter.create<Torch::AtenDivTensorOp>(
-            binder.getLoc(), resultType, add, sub);
-        Value log = rewriter.create<Torch::AtenLogOp>(binder.getLoc(),
-                                                      resultType, div);
-        Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(2));
-        rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
-            binder.op, resultType, log, cstTwo);
-        return success();
-      });
+  patterns.onOp("Atanh", 9,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value operand;
+                  if (binder.tensorOperand(operand) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+                  rewriter.replaceOpWithNewOp<Torch::AtenAtanhOp>(
+                      binder.op, resultType, operand);
+                  return success();
+                });
   patterns.onOp("Acos", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
                   Value operand;
@@ -270,29 +344,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                       binder.op, resultType, operand);
                   return success();
                 });
-  patterns.onOp(
-      "Acosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        Value operand;
-        if (binder.tensorOperand(operand) ||
-            binder.tensorResultType(resultType))
-          return failure();
-
-        // log(x + sqrt(x**2 - 1))
-        Value square = rewriter.create<Torch::AtenSquareOp>(
-            binder.getLoc(), resultType, operand);
-        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(1));
-        Value sub = rewriter.create<Torch::AtenSubScalarOp>(
-            binder.getLoc(), resultType, square, cstOne, cstOne);
-        Value sqrt = rewriter.create<Torch::AtenSqrtOp>(binder.getLoc(),
-                                                        resultType, sub);
-        Value add = rewriter.create<Torch::AtenAddTensorOp>(
-            binder.getLoc(), resultType, operand, sqrt, cstOne);
-        rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(binder.op, resultType,
-                                                      add);
-        return success();
-      });
+  patterns.onOp("Acosh", 9,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value operand;
+                  if (binder.tensorOperand(operand) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+                  rewriter.replaceOpWithNewOp<Torch::AtenAcoshOp>(
+                      binder.op, resultType, operand);
+                  return success();
+                });
   patterns.onOp("BatchNormalization", 15,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
@@ -1388,31 +1450,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                       binder.op, resultType, operand);
                   return success();
                 });
-  patterns.onOp(
-      "Cosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        Value operand;
-        if (binder.tensorOperand(operand) ||
-            binder.tensorResultType(resultType))
-          return failure();
-
-        // 1/2 * (exp(x) + exp(-x))
-        Value x = rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType,
-                                                    operand);
-        Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
-                                                      resultType, operand);
-        Value y =
-            rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, neg);
-        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(1));
-        Value z = rewriter.create<Torch::AtenAddTensorOp>(
-            binder.getLoc(), resultType, x, y, cstOne);
-        Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(2));
-        rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
-            binder.op, resultType, z, cstTwo);
-        return success();
-      });
+  patterns.onOp("Cosh", 9,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value operand;
+                  if (binder.tensorOperand(operand) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+                  rewriter.replaceOpWithNewOp<Torch::AtenCoshOp>(
+                      binder.op, resultType, operand);
+                  return success();
+                });
   patterns.onOp(
       "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Location loc = binder.getLoc();
@@ -2252,114 +2300,84 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               binder.tensorResultType(resultType)) {
             return failure();
           }
-          double isPeriodicFp = static_cast<double>(periodic);
+
+          Location loc = binder.getLoc();
           Value a0 = rewriter.create<Torch::ConstantFloatOp>(
-              binder.getLoc(),
-              rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
+              loc, rewriter.getF64FloatAttr(0.42));
           Value a1 = rewriter.create<Torch::ConstantFloatOp>(
-              binder.getLoc(),
-              rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
+              loc, rewriter.getF64FloatAttr(-0.5));
           Value a2 = rewriter.create<Torch::ConstantFloatOp>(
-              binder.getLoc(),
-              rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
+              loc, rewriter.getF64FloatAttr(0.08));
-          Value zero = rewriter.create<Torch::ConstantFloatOp>(
-              binder.getLoc(), rewriter.getF64FloatAttr(0.0));
-          Value one = rewriter.create<Torch::ConstantFloatOp>(
-              binder.getLoc(), rewriter.getF64FloatAttr(1.0));
-          Value two = rewriter.create<Torch::ConstantFloatOp>(
-              binder.getLoc(), rewriter.getF64FloatAttr(2.0));
-
-          constexpr double pi = llvm::numbers::pi;
-          Value tau = rewriter.create<Torch::ConstantFloatOp>(
-              binder.getLoc(),
-              rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
+
-          Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-          Value cstFalse =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
-          Value float32Type = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(/*float32Type*/ 6));
-
-          // Create an f32 ValueTensorType with thse same size as size, the
-          // operand
-          auto shapeOfOperand = size.getType()
-                                    .dyn_cast<Torch::ValueTensorType>()
-                                    .getOptionalSizes();
-          auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
-              shapeOfOperand, rewriter.getF32Type());
-          Value periodicSizeFloat = rewriter.create<Torch::AtenToDtypeOp>(
-              binder.getLoc(), f32ResultType, size, float32Type, cstFalse,
-              cstFalse, noneVal);
-          Value symmetricSizeFloat = rewriter.create<Torch::AtenSubScalarOp>(
-              binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
-              one, one);
-
-          Value isPeriodic = rewriter.create<Torch::ConstantFloatOp>(
-              binder.getLoc(), rewriter.getF64FloatAttr(isPeriodicFp));
-          Value isSymmetricFloat = rewriter.create<Torch::ConstantFloatOp>(
-              binder.getLoc(), rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
-
-          Value periodicComponent = rewriter.create<Torch::AtenMulScalarOp>(
-              binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
-              isPeriodic);
-          Value symmetricComponent = rewriter.create<Torch::AtenMulScalarOp>(
-              binder.getLoc(), symmetricSizeFloat.getType(), symmetricSizeFloat,
-              isSymmetricFloat);
-          Value sizeFloat = rewriter.create<Torch::AtenAddTensorOp>(
-              binder.getLoc(), symmetricComponent.getType(), symmetricComponent,
-              periodicComponent, one);
-
-          // Here, size can be used in the place of periodicSizeFloat, as the
-          // latter is just a float representation of the former.
-          Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
-
-          Value rangeArr = rewriter.create<Torch::AtenArangeStartStepOp>(
-              binder.getLoc(), resultType, zero, scalarLimit, one, noneVal,
-              noneVal, noneVal, noneVal);
-
-          Value rangeTimesTau = rewriter.create<Torch::AtenMulScalarOp>(
-              binder.getLoc(), resultType, rangeArr, tau);
-          Value rangeAngular = rewriter.create<Torch::AtenDivTensorOp>(
-              binder.getLoc(), resultType, rangeTimesTau, sizeFloat);
-          Value twoRangeAngular = rewriter.create<Torch::AtenMulScalarOp>(
-              binder.getLoc(), resultType, rangeAngular, two);
-
-          Value cosRangeAngular = rewriter.create<Torch::AtenCosOp>(
-              binder.getLoc(), resultType, rangeAngular);
-          Value cosTwoRangeAngular = rewriter.create<Torch::AtenCosOp>(
-              binder.getLoc(), resultType, twoRangeAngular);
-
-          Value a1Component = rewriter.create<Torch::AtenMulScalarOp>(
-              binder.getLoc(), resultType, cosRangeAngular, a1);
-          Value a2Component = rewriter.create<Torch::AtenMulScalarOp>(
-              binder.getLoc(), resultType, cosTwoRangeAngular, a2);
-
-          // AtenSubScalarOp actually requires a tensor operand as the LHS, that
-          // is, operand #1. Therefore, to avoid errors, the onnx implementation
-          // has been modified. a1 has been changed to negative half, and the
-          // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
-          // operation is commutative.
-          Value subA1Component = rewriter.create<Torch::AtenAddScalarOp>(
-              binder.getLoc(), resultType, a1Component, a0, one);
-          Value result = rewriter.create<Torch::AtenAddTensorOp>(
-              binder.getLoc(), resultType, subA1Component, a2Component, one);
+          auto windowFunctionResult =
+              windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
+                                 output_datatype, periodic);
 
-          std::optional<int64_t> dtypeIntTorch =
-              onnxDtypeIntToTorchDtypeInt(output_datatype);
-          if (!dtypeIntTorch.has_value()) {
-            return rewriter.notifyMatchFailure(
-                binder.op,
-                "unimplemented support for the given dtype conversion");
+          if (failed(windowFunctionResult))
+            return failure();
+
+          return success();
+        });
+
+  patterns.onOp(
+      "HannWindow", 17,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Value size;
+        Torch::ValueTensorType resultType;
+        int64_t periodic, output_datatype;
+        if (binder.tensorOperand(size) ||
+            binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
+            binder.s64IntegerAttr(periodic, "periodic", 1) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
         }
-          Value outputDtype = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                      dtypeIntTorch.value()));
-          rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
-              binder.op, resultType, result, outputDtype,
-              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
-              /*memory_format=*/noneVal);
 
+        Location loc = binder.getLoc();
+        Value a0 = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(0.5));
+        Value a1 = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(-0.5));
+        Value a2 = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(0.0));
+
+        auto windowFunctionResult =
+            windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
+                               output_datatype, periodic);
+
+        if (failed(windowFunctionResult))
+          return failure();
+
+        return success();
+      });
+
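+  // ONNX specifies the Hamming window with a0 = 25/46 (~0.543478) rather
+  // than the rounded textbook 0.54; a1 = a0 - 1.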
+  patterns.onOp(
+      "HammingWindow", 17,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Value size;
+        Torch::ValueTensorType resultType;
+        int64_t periodic, output_datatype;
+        if (binder.tensorOperand(size) ||
+            binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
+            binder.s64IntegerAttr(periodic, "periodic", 1) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+
+        Location loc = binder.getLoc();
+        Value a0 = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(0.543478));
+        Value a1 = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(-0.456522));
+        Value a2 = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(0.0));
+
+        auto windowFunctionResult =
+            windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
+                               output_datatype, periodic);
+
+        if (failed(windowFunctionResult))
+          return failure();
+        return success();
+      });
 }
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 739b6e16f9ce..7a6cae2e4eb6 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -851,9 +851,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.tensorResultType(resultType))
           return failure();
 
-        Torch::ValueTensorType inputType =
-            operand.getType().cast<Torch::ValueTensorType>();
-
         Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
@@ -862,31 +859,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), gamma));
 
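+        // ONNX Selu(x) = gamma * (x for x > 0, alpha * exp(x) - alpha
+        // otherwise), which matches torch.aten.elu with scale = gamma and
+        // input_scale = 1, so the old elementwise decomposition can be
+        // replaced by a single op.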
-        Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
+        Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));
 
-        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-        Value zeroTensor = rewriter.create<Torch::AtenZerosLikeOp>(
-            binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone,
-            cstNone, cstNone);
-        Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
-                                                      resultType, operand);
-        Value expMulAlpha = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, exp, vAlpha);
-        Value expMulAlphaSubAlpha = rewriter.create<Torch::AtenSubScalarOp>(
-            binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne);
-        Value neg = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale);
-        Value pos = rewriter.create<Torch::AtenMulScalarOp>(
-            binder.getLoc(), resultType, operand, vScale);
-        Type compareType = inputType.getWithSizesAndDtype(
-            inputType.getOptionalSizes(), rewriter.getI1Type());
-        Value xLessThanZero = rewriter.create<Torch::AtenLtTensorOp>(
-            binder.getLoc(), compareType, operand, zeroTensor);
-
-        rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
-            binder.op, resultType, xLessThanZero, neg, pos);
+        rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
+            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
         return success();
       });
   patterns.onOp("ReduceL1", 1,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
@@ -962,6 +940,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                       /*memory_format=*/noneVal);
                   return success();
                 });
+  patterns.onOp("ReduceLogSum", 1,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value data;
+                  int64_t keepDims, noop_with_empty_axes;
+                  if (binder.tensorOperandAtIndex(data, 0) ||
+                      binder.tensorResultType(resultType) ||
+                      binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+                      binder.s64IntegerAttr(noop_with_empty_axes,
+                                            "noop_with_empty_axes", 0))
+                    return failure();
+
+                  auto reducedSumBool =
+                      reducedSumImpl(binder, rewriter, data, resultType,
+                                     /*storeValue=*/data, keepDims,
+                                     noop_with_empty_axes, true);
+
+                  if (failed(reducedSumBool))
+                    return rewriter.notifyMatchFailure(
+                        binder.op,
+                        "Failed to perform sum operation on square of operand");
+
+                  rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(
+                      binder.op, resultType, data);
+                  return success();
+                });
   patterns.onOp("ReduceSum", 1,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
@@ -978,7 +982,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                                         /*storeValue=*/data, keepDims,
                                         noop_with_empty_axes, false);
                 });
-  patterns.onOp("ReduceLogSum", 1,
+  patterns.onOp("ReduceSumSquare", 1,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
                   Value data;
@@ -990,19 +994,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                                             "noop_with_empty_axes", 0))
                    return failure();
 
-                  auto reducedSumBool =
-                      reducedSumImpl(binder, rewriter, data, resultType,
-                                     /*storeValue=*/data, keepDims,
-                                     noop_with_empty_axes, true);
-
-                  if (failed(reducedSumBool))
-                    return rewriter.notifyMatchFailure(
-                        binder.op,
-                        "Failed to perform sum operation on square of operand");
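+                  // ReduceSumSquare(x) reduces to ReduceSum(x * x): square
+                  // elementwise first, then reuse the shared sum helper.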
+                  Value dataSquare = rewriter.create<Torch::AtenMulTensorOp>(
+                      binder.getLoc(), data.getType(), data, data);
 
-                  rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(
-                      binder.op, resultType, data);
-                  return success();
+                  return reducedSumImpl(binder, rewriter, dataSquare,
+                                        resultType,
+                                        /*storeValue=*/data, keepDims,
+                                        noop_with_empty_axes, false);
                 });
   patterns.onOp(
       "ReduceMean", 1,
@@ -1441,31 +1439,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         return success();
       });
 
-  patterns.onOp(
-      "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        Value operand;
-        if (binder.tensorOperand(operand) ||
-            binder.tensorResultType(resultType))
-          return failure();
+  patterns.onOp("Sinh", 9,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value operand;
+                  if (binder.tensorOperand(operand) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
 
-        // 1/2 * (exp(x) - exp(-x))
-        Value x = rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType,
-                                                    operand);
-        Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
-                                                      resultType, operand);
-        Value y =
-            rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, neg);
-        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(1));
-        Value z = rewriter.create<Torch::AtenSubTensorOp>(
-            binder.getLoc(), resultType, x, y, cstOne);
-        Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(2));
-        rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
-            binder.op, resultType, z, cstTwo);
-        return success();
-      });
+                  rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(
+                      binder.op, resultType, operand);
+                  return success();
+                });
 
 // split with fixed-size parts
 // Arguments:
@@ -2777,4 +2762,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             /*generator=*/cstNone);
         return success();
       });
+  patterns.onOp(
+      "SoftmaxCrossEntropyLoss", 12,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        int64_t ignoreIndex;
+        std::string reduction;
+        SmallVector<int64_t> shape;
+        Value scores, labels, weight;
+        if (binder.tensorOperandAtIndex(scores, 0) ||
+            binder.tensorOperandAtIndex(labels, 1) ||
+            binder.s64IntegerAttr(ignoreIndex, "ignore_index", -100) ||
+            binder.customOpNameStringAttr(reduction, "reduction", "mean") ||
+            binder.tensorResultTypeAtIndex(resultType, 0)) {
+          return failure();
+        }
+
+        if (binder.tensorOperandAtIndex(weight, 2))
+          weight = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+
+        Value cstIgnoreIndex = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(ignoreIndex));
+
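+        // Map the ONNX reduction string onto the integer encoding used by
+        // torch.aten.cross_entropy_loss (0 = none, 1 = mean, 2 = sum).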
+        int64_t reductionInt = reduction == "none"   ? 0
+                               : reduction == "mean" ? 1
+                                                     : 2;
+        Value cstReductionInt = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(reductionInt));
+
+        // The default PyTorch value for label smoothing is "0.0".
+        // Refer:
+        // https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+        Value cstLabelSmoothing = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
+
+        Value loss = rewriter.create<Torch::AtenCrossEntropyLossOp>(
+            binder.getLoc(), resultType, scores, labels, weight,
+            cstReductionInt, cstIgnoreIndex, cstLabelSmoothing);
+
+        if (binder.op->getNumResults() == 1) {
+          rewriter.replaceOp(binder.op, loss);
+          return success();
+        }
+
+        Torch::ValueTensorType resultTypeLogProb;
+        if (binder.tensorResultTypeAtIndex(resultTypeLogProb, 1))
+          return failure();
+
+        Value dim = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value logProb = rewriter.create<Torch::AtenLogSoftmaxIntOp>(
+            binder.getLoc(), resultTypeLogProb, scores, dim, /*dtype=*/cstNone);
+
+        rewriter.replaceOp(binder.op, {loss, logProb});
+        return success();
+      });
 }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ff65f4a142d3..bf2092f23127 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2813,9 +2813,6 @@
     "IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
     "IndexPutHackedTwin3DIntAccumulateModule_basic",
     "IndexPutHackedTwin3DIntNonAccumulateModule_basic",
-    # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
-    "CrossEntropyLossModule_basic",
-    "CrossEntropyLossNoReductionModule_basic",
     # RuntimeError: unsupported input type: Device
     "PrimsIotaModule_basic",
     # Failure - unknown
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
index 204ddf61674b..2a63c06bdc37 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
@@ -3,8 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-from typing import Union, Optional, Sequence
-
 import numpy as np
 import torch
 import torch.utils._pytree as pytree
@@ -12,15 +10,6 @@
 from torch.export import ExportedProgram
 
 from torch_mlir import fx
-from torch_mlir.compiler_utils import (
-    run_pipeline_with_repro_report,
-    lower_mlir_module,
-    OutputType,
-)
-from torch_mlir.torchscript import (
-    BACKEND_LEGAL_OPS,
-    _canon_extra_library,
-)
 from torch_mlir_e2e_test.configs.utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,
@@ -39,53 +28,6 @@ def refine_result_type(_result):
         raise ValueError(f"Unhandled return type {type(_result)}")
 
 
-def jit(
-    prog: ExportedProgram,
-    func_name: str,
-    output_type: Union[str, "OutputType"] = OutputType.TORCH,
-    backend_legal_ops: Optional[Sequence[str]] = None,
-    extra_library=None,
-    verbose: bool = False,
-):
-    if extra_library is None:
-        extra_library = []
-    mlir_module = None
-
-    extra_library_file_name = _canon_extra_library(extra_library)
-    output_type = OutputType.get(output_type)
-    if backend_legal_ops is not None:
-        if output_type != OutputType.TORCH:
-            raise Exception(
-                "`backend_legal_ops` is only valid with the " "`torch` output type"
-            )
-        backend_legal_ops = list(sorted(set(backend_legal_ops)))
-    else:
-        backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
-
-    option_string = (
-        "{backend-legal-ops="
-        + ",".join(backend_legal_ops)
-        + " extra-library="
-        + extra_library_file_name
-        + "}"
-    )
-
-    mlir_module = fx.export_and_import(prog, func_name=func_name)
-    assert mlir_module is not None
-    run_pipeline_with_repro_report(
-        mlir_module,
-        f"builtin.module(torch-simplification-pipeline)",
-        "Simplification pipeline for torch dialect",
-    )
-    run_pipeline_with_repro_report(
-        mlir_module,
-        f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
-        "Lowering TorchFX IR -> Torch Backend IR",
-    )
-
-    return lower_mlir_module(verbose, output_type, mlir_module)
-
-
 class FxImporterTestConfig(TestConfig):
     """TestConfig that runs the torch.nn.Module with Fx Importer"""
 
@@ -100,11 +42,11 @@ def compile(self, program: torch.nn.Module) -> torch.nn.Module:
     def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
         result: Trace = []
         for item in trace:
-            prog = torch.export.export(artifact, tuple(item.inputs))
-            module = jit(
+            prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
+            module = fx.export_and_import(
                 prog,
-                func_name=artifact.__class__.__name__,
                 output_type=self._output_type,
+                func_name=artifact.__class__.__name__,
             )
             module = self._backend.compile(module)
             backend_module = self._backend.load(module)
diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py
index 651ccae673a6..8d5c5cb1125c 100644
--- a/python/torch_mlir/fx.py
+++ b/python/torch_mlir/fx.py
@@ -16,11 +16,50 @@
 from torch_mlir import ir
 from torch_mlir.dialects import torch as torch_d
 from torch_mlir.extras.fx_decomp_util import get_decomposition_table
+from torch_mlir.compiler_utils import (
+    OutputType,
+    run_pipeline_with_repro_report,
+    lower_mlir_module,
+)
+
+
+def _module_lowering(
+    verbose,
+    output_type,
+    torch_mod,
+    backend_legal_ops=None,
+    extra_library_file_name=None,
+):
+
+    if output_type == OutputType.TORCH:
+        if verbose:
+            print(torch_mod)
+        return torch_mod
+    # TODO: pass backend_legal_ops/extra_library_file_name by caller
+    if backend_legal_ops is None:
+        backend_legal_ops = []
+    if extra_library_file_name is None:
+        extra_library_file_name = ""
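+    # Build the option string consumed by the
+    # torch-function-to-torch-backend-pipeline pass pipeline, e.g.
+    # "{backend-legal-ops=aten.flatten.using_ints extra-library=}".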
",".join(backend_legal_ops) + + " extra-library=" + + extra_library_file_name + + "}" + ) + run_pipeline_with_repro_report( + torch_mod, + f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})", + "Lowering TorchFX IR -> Torch Backend IR", + enable_ir_printing=verbose, + ) + return lower_mlir_module(verbose, output_type, torch_mod) def export_and_import( f: Union[nn.Module, ExportedProgram], *args, + output_type: Union[str, OutputType] = OutputType.TORCH, fx_importer: Optional[FxImporter] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, experimental_support_mutation: bool = False, @@ -28,6 +67,7 @@ def export_and_import( decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None, func_name: str = "main", enable_graph_printing: bool = False, + enable_ir_printing: bool = False, **kwargs, ): context = ir.Context() @@ -52,15 +92,19 @@ def export_and_import( else: fx_importer.import_frozen_program(prog, func_name=func_name) - return fx_importer.module + return _module_lowering( + enable_ir_printing, OutputType.get(output_type), fx_importer.module + ) def stateless_fx_import( gm: torch.fx.GraphModule, + output_type: Union[str, OutputType] = OutputType.TORCH, fx_importer: Optional[FxImporter] = None, hooks: Optional[FxImporterHooks] = None, model_name: str = "main", enable_graph_printing: bool = False, + enable_ir_printing: bool = False, ): if enable_graph_printing: gm.print_readable() @@ -69,4 +113,6 @@ def stateless_fx_import( if fx_importer is None: fx_importer = FxImporter(context=context, hooks=hooks) fx_importer.import_stateless_graph(gm.graph, func_name=model_name) - return fx_importer.module + return _module_lowering( + enable_ir_printing, OutputType.get(output_type), fx_importer.module + ) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index a068acbf2941..e8266c04ffad 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -201,14 +201,7 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK-LABEL: @test_atanh func.func @test_atanh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[C1:.*]] = torch.constant.int 1 - // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg0, %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[NEG:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[SUB:.*]] = torch.aten.add.Scalar %[[NEG]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[DIV:.*]] = torch.aten.div.Tensor %[[ADD]], %[[SUB]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[LOG:.*]] = torch.aten.log %[[DIV]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[C2:.*]] = torch.constant.int 2 - // CHECK: torch.aten.div.Scalar %[[LOG]], %[[C2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.atanh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Atanh"(%arg0) : 
   %0 = torch.operator "onnx.Atanh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
@@ -672,13 +665,7 @@ func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5
 
 // CHECK-LABEL: @test_cosh_example
 func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
-  // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
-  // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
-  // CHECK: %[[C1:.+]] = torch.constant.int 1
-  // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
-  // CHECK: %[[C2:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.div.Scalar %[[ADD]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
+  // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
   %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
 }
@@ -687,13 +674,7 @@ func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[
 
 // CHECK-LABEL: @test_cosh
 func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[C1:.+]] = torch.constant.int 1
-  // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[C2:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.div.Scalar %[[ADD]], %[[C2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
   %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
@@ -702,12 +683,7 @@ func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
 
 // CHECK-LABEL: @test_acosh_example
 func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
-  // CHECK: %[[C1:.+]] = torch.constant.int 1
-  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
-  // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[SUB]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
-  // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
-  // CHECK: torch.aten.log %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
   %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
 }
@@ -716,12 +692,7 @@ func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<
 
 // CHECK-LABEL: @test_acosh
 func.func @test_acosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[C1:.+]] = torch.constant.int 1
-  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[SUB]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: torch.aten.log %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
   %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
@@ -748,12 +719,7 @@ func.func @test_asin(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
 
 // CHECK-LABEL: @test_asinh_example
 func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
-  // CHECK: %[[C1:.+]] = torch.constant.int 1
-  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
-  // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
-  // CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
-  // CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
   %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
 }
@@ -762,12 +728,7 @@ func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<
 
 // CHECK-LABEL: @test_asinh
 func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[C1:.+]] = torch.constant.int 1
-  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
   %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
@@ -2113,3 +2074,162 @@ func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor
   %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
   return %0 : !torch.vtensor<[10],f32>
 }
+// -----
+
+// CHECK-LABEL: func.func @test_hannwindow
+func.func @test_hannwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
+  // CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
+  // CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
+  // CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
+
+  %0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_hannwindow_symmetric
+func.func @test_hannwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
+  // CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
+  // CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
+  // CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
+
+  %0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_hammingwindow_symmetric
+func.func @test_hammingwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.434780e-01
+  // CHECK-DAG: %[[A1:.+]] = torch.constant.float -4.565220e-01
+  // CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
+  // CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
+  // CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
+
+  %0 = torch.operator "onnx.HammingWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_hammingwindow
+func.func @test_hammingwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.434780e-01
+  // CHECK-DAG: %[[A1:.+]] = torch.constant.float -4.565220e-01
+  // CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
+  // CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
+  // CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
+
+  %0 = torch.operator "onnx.HammingWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 6c6300d2e37d..1498ea257048 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -582,18 +582,10 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to
 
 // CHECK-LABEL: func.func @test_selu
 func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
-  // CHECK: %[[F2:.+]] = torch.constant.float 2.000000e+00
-  // CHECK: %[[F3:.+]] = torch.constant.float 3.000000e+00
-  // CHECK: %[[F1:.+]] = torch.constant.float 1.000000e+00
-  // CHECK: %[[NONE:.+]] = torch.constant.none
-  // CHECK: %[[ZEROS:.+]] = torch.aten.zeros_like %arg0, %none, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[EXP:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[EXP]], %[[F2]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[MUL]], %[[F2]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[MUL_1:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[MUL_2:.+]] = torch.aten.mul.Scalar %arg0, %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
-  // CHECK: %[[LT:.+]] = torch.aten.lt.Tensor %arg0, %[[ZEROS]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1>
-  // CHECK: torch.aten.where.self %[[LT]], %[[MUL_1]], %[[MUL_2]] : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
+  // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
+  // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3
+  // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]]
   %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
@@ -868,6 +860,57 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2
 
 // -----
 
+// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
+func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example +func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int> + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example +func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int> + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example +func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int> + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] :
!torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -950,53 +993,86 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< // ----- -// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example -func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-LABEL: func.func @test_reduce_sum_square_default_axes_keepdims_example +func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int> // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32> - // CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32> - %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[SUM]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> return %0 : !torch.vtensor<[1,1,1],f32> } // ----- -// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example -func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
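+// ReduceSumSquare lowers as Sum(x * x): the tests below expect torch.aten.mul.Tensor of the operand with itself feeding torch.aten.sum.dim_IntList.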
+// CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example +func.func @test_reduce_sum_square_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int> + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: return %[[SUM]] : !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero +func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[2,0,4],f32>, !torch.vtensor<[2,0,4],f32> -> !torch.vtensor<[2,0,4],f32> + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT2]] : (!torch.int) -> !torch.list<int> // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> - // CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32> - %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> - return %0 : !torch.vtensor<[3,2,1],f32> + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32> + // CHECK: return %[[SUM]] : !torch.vtensor<[2,0,1],f32> + %0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> + return %0 : !torch.vtensor<[2,0,1],f32> } // ----- -// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example
-func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_example +func.func @test_reduce_sum_square_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int> - // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32> - // CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32> - %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> - return %0 : !torch.vtensor<[3,2],f32> + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + // CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_int_example +func.func @test_reduce_sum_square_keepdims_int_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],si64>, !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
+ // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + // CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> } // ----- @@ -1265,15 +1341,9 @@ func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>, // ----- -// CHECK-LABEL: func.func @test_sinh_example +// CHECK-LABEL: func.func @test_sinh func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { - // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[C1:.+]] = torch.constant.int 1 - // CHECK: %[[SUB:.+]] = torch.aten.sub.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> - // CHECK: %[[C2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.div.Scalar %[[SUB]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } @@ -1939,3 +2009,34 @@ func.func @test_random_uniform_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.v %0 = torch.operator "onnx.RandomUniformLike"(%arg0) {torch.onnx.dtype = 1 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> return %0 : !torch.vtensor<[10],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_sce_mean_3d +func.func @test_sce_mean_3d(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[IGNORE_INDEX:.+]] = torch.constant.int -100 + // CHECK: %[[REDUCTION:.+]] = torch.constant.int 1 + // CHECK: %[[LABEL_SMOOTHING:.+]] = torch.constant.float 0.000000e+00 + // CHECK: %[[LOSS:.+]] = torch.aten.cross_entropy_loss %arg0, %arg1, %[[NONE]], %[[REDUCTION]], %[[IGNORE_INDEX]], %[[LABEL_SMOOTHING]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int, !torch.float -> !torch.vtensor<[],f32> + // CHECK: return %[[LOSS]] : !torch.vtensor<[],f32> + %0 = torch.operator "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- +
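+// The next test exercises the two-result form: besides the loss, the ONNX op returns log-probabilities, lowered as torch.aten.log_softmax.int along dim 1.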
+// CHECK-LABEL: func.func @test_sce_mean_3d_log_prob +func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[IGNORE_INDEX:.+]] = torch.constant.int -100 + // CHECK: %[[REDUCTION:.+]] = torch.constant.int 1 + // CHECK: %[[LABEL_SMOOTHING:.+]] = torch.constant.float 0.000000e+00 + // CHECK: %[[LOSS:.+]] = torch.aten.cross_entropy_loss %arg0, %arg1, %[[NONE]], %[[REDUCTION]], %[[IGNORE_INDEX]], %[[LABEL_SMOOTHING]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int, !torch.float -> !torch.vtensor<[],f32> + // CHECK: %[[DIM:.+]] = torch.constant.int 1 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[PROB:.+]] = torch.aten.log_softmax.int %arg0, %[[DIM]], %[[NONE_0]] : !torch.vtensor<[3,5,2],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,5,2],f32> + // CHECK: return %[[LOSS]], %[[PROB]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32> + %0:2 = torch.operator "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>) + return %0#0, %0#1 : !torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32> +} diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir index 92888616a67b..d8ec0fa6495f 100644 --- a/test/Conversion/TorchToStablehlo/basic.mlir +++ b/test/Conversion/TorchToStablehlo/basic.mlir @@ -55,7 +55,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt // CHECK-LABEL: func.func @torch.aten.reciprocal( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32> -// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32> +// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32> // CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32> @@ -124,7 +124,7 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?], // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32> // CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex> -// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) +// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
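+// Note: newer StableHLO prints an op's inherent attributes as properties, i.e. "op"(...) <{...}>, instead of a trailing attribute dict; the check lines in these conversion tests are updated to the new syntax.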
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { @@ -152,7 +152,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) // CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32> // CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex> // CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32> -// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32> +// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32> // CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32> // CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32> @@ -185,7 +185,7 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>) // CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32> // CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> // CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32> -// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) +// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) // CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32> func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { @@ -214,7 +214,7 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?], // CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32> // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32> // CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32> -// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) +// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) // CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64> // CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32> // CHECK: %[[VAL_14:.*]] = stablehlo.constant
dense<[3, 7, 1, 1]> : tensor<4xi64> diff --git a/test/Conversion/TorchToStablehlo/elementwise.mlir b/test/Conversion/TorchToStablehlo/elementwise.mlir index 3ff1d095c532..329fc1b96cd7 100644 --- a/test/Conversion/TorchToStablehlo/elementwise.mlir +++ b/test/Conversion/TorchToStablehlo/elementwise.mlir @@ -4,9 +4,9 @@ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[STR:.*]] = torch.constant.str "none" -// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> -// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> -// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> +// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32> +// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 2.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32> +// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 5.000000e-01 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32> // CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32> // CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32> @@ -475,7 +475,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch // CHECK-LABEL: func.func @torch.aten.relu( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> -// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> +// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 0.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32> // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32> diff --git a/test/Conversion/TorchToStablehlo/gather.mlir b/test/Conversion/TorchToStablehlo/gather.mlir index a88b6e375071..df29bf1d4cca 100644 --- a/test/Conversion/TorchToStablehlo/gather.mlir +++ b/test/Conversion/TorchToStablehlo/gather.mlir @@ -10,7 +10,7 @@ // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32> // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK:
%[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false}> : (tensor, tensor, tensor<2xi64>) -> tensor // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32> @@ -53,7 +53,7 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false}> : (tensor, tensor, tensor<2xi64>) -> tensor // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,1,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32> diff --git a/test/Conversion/TorchToStablehlo/pooling.mlir b/test/Conversion/TorchToStablehlo/pooling.mlir index b8fc6cbd8384..156c3ff51be2 100644 --- a/test/Conversion/TorchToStablehlo/pooling.mlir +++ b/test/Conversion/TorchToStablehlo/pooling.mlir @@ -14,11 +14,11 @@ // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<0xFF800000> : tensor -// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({ +// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor): // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor // CHECK: stablehlo.return %[[VAL_10]] : tensor -// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor +// CHECK: }) : (tensor, tensor) -> tensor // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { @@ -46,12 +46,12 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor -// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ +// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) +// CHECK{LITERAL}: <{padding = 
+// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>): // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32> // CHECK: stablehlo.return %[[VAL_10]] : tensor<f32> -// CHECK: }) -// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> +// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { @@ -96,7 +96,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64> // CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64> // CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64> -// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({ +// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) <{padding = dense<0> : tensor<3x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>): // CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32> @@ -105,7 +105,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64> // CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64> // CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64> -// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>) +// CHECK: }) : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>) // CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64> // CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64> @@ -137,11 +137,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> -// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ +// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) +// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>): // CHECK: %[[IVAL_2:.*]] =
stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32> // CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32> -// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> +// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32> // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32> @@ -158,11 +159,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64> // CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32> // CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> -// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({ +// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) +// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>): // CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32> // CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32> -// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> +// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32> // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32> @@ -194,11 +196,12 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> -// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({ +// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) +// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): // CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> // CHECK: stablehlo.return %[[T10]] : tensor<f32> -// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> +// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64> // CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32> // CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> diff --git a/test/Conversion/TorchToStablehlo/scatter.mlir b/test/Conversion/TorchToStablehlo/scatter.mlir index 432fc0c86c5f..fe8ffb9ee205 100644 --- a/test/Conversion/TorchToStablehlo/scatter.mlir +++
b/test/Conversion/TorchToStablehlo/scatter.mlir @@ -22,10 +22,10 @@ // CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x1xi64> // CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor<?x?x1xi64> // CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor<?x?x1xi64>, tensor<?x?x1xi64>) -> tensor<?x?x2xi64> -// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) ({ +// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ // CHECK: ^bb0(%arg3: tensor<i64>, %[[ARG_4:.*]]: tensor<i64>): // CHECK: stablehlo.return %[[ARG_4]] : tensor<i64> -// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64> +// CHECK: }) : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64> // CHECK: %[[VAR_10:.*]] = torch_c.from_builtin_tensor %[[VAR_9]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64> // CHECK: return %[[VAR_10]] : !torch.vtensor<[?,?],si64> func.func @forward(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {