diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h
index b214e147d5d9..dd7cfb5c428f 100644
--- a/include/torch-mlir-c/TorchTypes.h
+++ b/include/torch-mlir-c/TorchTypes.h
@@ -220,6 +220,19 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
 /// Gets the !torch.quint8 typeid.
 MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void);
 
+//===----------------------------------------------------------------------===//
+// torch.qint16 type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.qint16 type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt16(MlirType t);
+
+/// Gets the !torch.qint16 type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt16TypeGet(MlirContext context);
+
+/// Gets the !torch.qint16 typeid.
+MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt16TypeGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
index 0de85f4eebe5..f296b6dfee5c 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -110,6 +110,18 @@ struct OpBinder {
     return success();
   }
 
+  ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) {
+    if (idx >= op->getNumOperands())
+      return failure();
+    valueIdx = op->getOperand(idx);
+    auto tt = dyn_cast<Torch::ListType>(valueIdx.getType());
+    if (!tt)
+      return failure();
+    if (!toValidTensorType(tt.getContainedType()))
+      return failure();
+    return success();
+  }
+
   ParseResult tensorListResultType(Torch::ListType &type0) {
     if (op->getNumResults() != 1)
       return failure();
diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
index d8d2534f9a0c..4bf6c845c68a 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -36,7 +36,7 @@ Value createConstantIntList(OpBinder binder,
                             ConversionPatternRewriter &rewriter,
                             SmallVector<int64_t> cstInput);
 
-Type getQTorchTypeFromTorchIntType(Type ty);
+Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);
 
 template <typename T>
 Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
index 279e694540f9..367b08610cd8 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -315,6 +315,16 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
   }];
 }
 
+def Torch_QInt16Type : Torch_Type<"QInt16", "qint16"> {
+  let summary = "Type modeling `ScalarType::QInt16`, which doesn't yet exist";
+  let description = [{
+    PyTorch does not have 16-bit integer quantization support.
+
+    This torch type is added to provide a target for 16-bit quantization
+    schemes coming from imported ONNX models.
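+
+    Like the other quantized integer types, it converts to a signless
+    integer of the same bit width at the backend type boundary.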
+  }];
+}
+
 def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
   let summary = "Type modeling `ScalarType::QUInt8`";
   let description = [{
diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
index 043dd92549b2..e2b57538d7e6 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -86,24 +86,34 @@ enum class TypeKind {
 // at:: and c10:: parts of the macro are never used within the compiler -- we
 // only use this for the enum values.
 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
-  _(uint8_t, Byte) /* 0 */ \
-  _(int8_t, Char) /* 1 */ \
-  _(int16_t, Short) /* 2 */ \
-  _(int, Int) /* 3 */ \
-  _(int64_t, Long) /* 4 */ \
-  _(at::Half, Half) /* 5 */ \
-  _(float, Float) /* 6 */ \
-  _(double, Double) /* 7 */ \
-  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
-  _(c10::complex<float>, ComplexFloat) /* 9 */ \
-  _(c10::complex<double>, ComplexDouble) /* 10 */ \
-  _(bool, Bool) /* 11 */ \
-  _(c10::qint8, QInt8) /* 12 */ \
-  _(c10::quint8, QUInt8) /* 13 */ \
-  _(c10::qint32, QInt32) /* 14 */ \
-  _(at::BFloat16, BFloat16) /* 15 */ \
-  _(c10::quint4x2, QUInt4x2) /* 16 */ \
-  _(c10::quint2x4, QUInt2x4) /* 17 */
+  _(uint8_t, Byte) /* 0 */ \
+  _(int8_t, Char) /* 1 */ \
+  _(int16_t, Short) /* 2 */ \
+  _(int, Int) /* 3 */ \
+  _(int64_t, Long) /* 4 */ \
+  _(at::Half, Half) /* 5 */ \
+  _(float, Float) /* 6 */ \
+  _(double, Double) /* 7 */ \
+  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
+  _(c10::complex<float>, ComplexFloat) /* 9 */ \
+  _(c10::complex<double>, ComplexDouble) /* 10 */ \
+  _(bool, Bool) /* 11 */ \
+  _(c10::qint8, QInt8) /* 12 */ \
+  _(c10::quint8, QUInt8) /* 13 */ \
+  _(c10::qint32, QInt32) /* 14 */ \
+  _(at::BFloat16, BFloat16) /* 15 */ \
+  _(c10::quint4x2, QUInt4x2) /* 16 */ \
+  _(c10::quint2x4, QUInt2x4) /* 17 */ \
+  _(c10::bits1x8, Bits1x8) /* 18 */ \
+  _(c10::bits2x4, Bits2x4) /* 19 */ \
+  _(c10::bits4x2, Bits4x2) /* 20 */ \
+  _(c10::bits8, Bits8) /* 21 */ \
+  _(c10::bits16, Bits16) /* 22 */ \
+  _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
+  _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
+  _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
+  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
+  _(c10::qint16, QInt16) /* 27 */
 
 enum class ScalarType : int8_t {
 #define DEFINE_ENUM(_1, n) n,
diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp
index 399915459e40..6402e44a3701 100644
--- a/lib/CAPI/TorchTypes.cpp
+++ b/lib/CAPI/TorchTypes.cpp
@@ -269,6 +269,22 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
   return wrap(Torch::QUInt8Type::getTypeID());
 }
 
+//===----------------------------------------------------------------------===//
+// torch.qint16 type.
+//===----------------------------------------------------------------------===//
+
+bool torchMlirTypeIsATorchQInt16(MlirType t) {
+  return isa<Torch::QInt16Type>(unwrap(t));
+}
+
+MlirType torchMlirTorchQInt16TypeGet(MlirContext context) {
+  return wrap(Torch::QInt16Type::get(unwrap(context)));
+}
+
+MlirTypeID torchMlirTorchQInt16TypeGetTypeID() {
+  return wrap(Torch::QInt16Type::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
//===----------------------------------------------------------------------===//
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index b26e1ea3a5f1..d0ff6e973a7e 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -737,7 +737,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                 std::numeric_limits<float>::lowest()))
           return failure();
         auto minSplatAttr = SplatElementsAttr::get(
-            resultType.toBuiltinTensor().clone(resultDtype),
+            resultType.toBuiltinTensor(),
             rewriter.getFloatAttr(resultDtype, minValue));
         min = rewriter.create<Torch::ValueTensorLiteralOp>(
             binder.getLoc(), resultType, minSplatAttr);
@@ -748,7 +748,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                 std::numeric_limits<float>::max()))
           return failure();
         auto maxSplatAttr = SplatElementsAttr::get(
-            resultType.toBuiltinTensor().clone(resultDtype),
+            resultType.toBuiltinTensor(),
             rewriter.getFloatAttr(resultDtype, maxValue));
         max = rewriter.create<Torch::ValueTensorLiteralOp>(
             binder.getLoc(), resultType, maxSplatAttr);
@@ -829,7 +829,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       return success();
     });
   patterns.onOp(
-      "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+      "Concat", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        SmallVector<Value> tensors;
        int64_t dim;
@@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         if (binder.op->hasAttr("torch.onnx.value_float") &&
             !binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
           auto splatAttr =
-              SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
+              SplatElementsAttr::get(resultType.toBuiltinTensor(),
                                      rewriter.getFloatAttr(dtype, floatValue));
           rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
               binder.op, resultType, splatAttr);
@@ -872,7 +872,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         if (binder.op->hasAttr("torch.onnx.value_int") &&
             !binder.s64IntegerAttr(intValue, "value_int", 0)) {
           auto splatAttr =
-              SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
+              SplatElementsAttr::get(resultType.toBuiltinTensor(),
                                      rewriter.getIntegerAttr(dtype, intValue));
           rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
               binder.op, resultType, splatAttr);
@@ -932,8 +932,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         for (auto intVal : intValues) {
           apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
         }
-        auto attr = DenseElementsAttr::get(
-            resultType.toBuiltinTensor().clone(dtype), apValues);
+        auto attr =
+            DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues);
         rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
             binder.op, resultType, attr);
         return success();
@@ -1715,21 +1715,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                                              "requires known result dtype");
       if (scaleTy.getSizes().size() == 0 ||
           (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
-        Type qTy = operandTy.getDtype();
-
-        if (qTy.isUnsignedInteger(8)) {
-          qTy = rewriter.getType<Torch::QUInt8Type>();
-        } else if (qTy.isSignedInteger(8)) {
-          qTy = rewriter.getType<Torch::QInt8Type>();
-        } else if (qTy.isSignedInteger(32)) {
-          qTy = rewriter.getType<Torch::QInt32Type>();
-        } else {
+        auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
+        if (!qTensorTy) {
           return rewriter.notifyMatchFailure(binder.op,
                                              "unsupported result dtype");
         }
 
-        auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
-            resultType.getOptionalSizes(), qTy);
         scale = rewriter.create<Torch::AtenItemOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
         zeropoint = rewriter.create<Torch::AtenItemOp>(
@@ -2272,9 +2263,9 @@ void
mlir::torch::onnx_c::populateDefaultDomainAtoF( // Extract the fill value and dtype // ONNX requires value attr to be a tensor if (!attr) { - attr = DenseElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDType), - rewriter.getFloatAttr(resultDType, 0.0)); + attr = + DenseElementsAttr::get(resultType.toBuiltinTensor(), + rewriter.getFloatAttr(resultDType, 0.0)); } // If its a dense resource attr we need to convert to a dense type: diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index fa338e1ba90d..3c6d82e103b5 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -408,20 +408,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(1.0)); - auto q = [&](Type qty) -> Type { - if (qty.isSignedInteger(8)) - return rewriter.getType(); - if (qty.isUnsignedInteger(8)) - return rewriter.getType(); - if (qty.isSignedInteger(32)) - return rewriter.getType(); - return {}; - }; + auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy); + auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy); - Type lhsQTy = rewriter.getType( - lhsTy.getOptionalSizes(), q(lhsTy.getDtype())); - Type rhsQTy = rewriter.getType( - rhsTy.getOptionalSizes(), q(rhsTy.getDtype())); + if (!lhsQTy || !rhsQTy) + return rewriter.notifyMatchFailure(binder.op, "failed to get qtype"); lhs = rewriter.create( binder.getLoc(), lhsQTy, lhs, scale, lhsZp); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 1e640979bf3a..edcbaa7d5173 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -177,22 +177,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "requires known result dtype"); if (scaleTy.getSizes().size() == 0) { - Type qTy = resultType.getDtype(); - - if (qTy.isUnsignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(32)) { - qTy = rewriter.getType(); - } else { + auto qTensorTy = getQTorchTypeFromTorchIntType(resultType); + if (!qTensorTy) { return rewriter.notifyMatchFailure(binder.op, "unsupported result dtype"); } - auto qTensorTy = rewriter.getType( - resultType.getOptionalSizes(), qTy); - auto torchqTy = Torch::getScalarTypeForType(qTy); + auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype()); Value tyConst = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -311,8 +302,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( c = rewriter.create(binder.getLoc(), cTy, c); - cTy = dyn_cast( - getQTorchTypeFromTorchIntType(resultType)); + cTy = getQTorchTypeFromTorchIntType(resultType); Value dtyVal = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( @@ -2963,28 +2953,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Torch::ValueTensorType resultType; llvm::SmallVector operands; std::string mode, nearest_mode, coordTfMode; + int64_t antialias, exclude_outside; + float extrapolation_value; Value noneVal = rewriter.create(binder.getLoc()); - if (auto attr = binder.op->getAttr("torch.onnx.antialias")) { - return rewriter.notifyMatchFailure( - binder.op, - "unimplemented: support not present for antialias attribute"); - } if (auto attr = binder.op->getAttr("torch.onnx.axes")) { return rewriter.notifyMatchFailure( 
binder.op, "unimplemented: support not present for axes attribute"); } - if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: support not present for " - "exclude_outside attribute"); - } - if (auto attr = binder.op->getAttr("torch.onnx.extrapolation_value")) { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: support not present for " - "extrapolation_value attribute"); - } if (auto attr = binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) { return rewriter.notifyMatchFailure( @@ -2997,18 +2974,38 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.customOpNameStringAttr(mode, "mode", "nearest") || binder.customOpNameStringAttr( coordTfMode, "coordinate_transformation_mode", "half_pixel") || + binder.s64IntegerAttr(antialias, "antialias", 0) || + binder.s64IntegerAttr(exclude_outside, "exclude_outside", 0) || + binder.f32FloatAttr(extrapolation_value, "extrapolation_value", + 0.0) || binder.customOpNameStringAttr(nearest_mode, "nearest_mode", "round_prefer_floor")) return failure(); + if (antialias != 0) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for antialias attribute"); + } + if (exclude_outside != 0) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "exclude_outside attribute"); + } + if (extrapolation_value != 0.0) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "extrapolation_value attribute"); + } if (coordTfMode == "tf_crop_and_resize") return rewriter.notifyMatchFailure( binder.op, "unimplemented: coordinate transformation mode: " "tf_crop_and_resize"); - if (mode == "nearest" && coordTfMode != "asymmetric") { + if (mode == "nearest" && coordTfMode != "asymmetric" && + coordTfMode != "half_pixel") { return rewriter.notifyMatchFailure( binder.op, "unimplemented: support not present for coord tf mode " - "except asymmetric"); + "except asymmetric and half_pixel"); } unsigned rank = dyn_cast(operands[0].getType()) @@ -3302,7 +3299,145 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.replaceOpWithNewOp( binder.op, resultType, inputLTNegLambd, inputPlusBias, inputSubBiasOrZero); + return success(); + }); + patterns.onOp("SequenceAt", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value inputSequence, position; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorOperandAtIndex(position, 1) || + binder.tensorResultType(resultType)) + return failure(); + + Value index = rewriter.create( + binder.getLoc(), rewriter.getType(), + position); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputSequence, index); + return success(); + }); + patterns.onOp( + "SequenceEmpty", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + int64_t dtypeIntOnnx; + if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.tensorListResultType(resultType)) + return failure(); + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value shapeList = createConstantIntList(binder, rewriter, {}); + Value cstNone = 
rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+
+          Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
+              binder.op->getLoc(), resultType.getContainedType(), shapeList,
+              /*dtype=*/constDtype,
+              /*layout=*/cstNone,
+              /*device=*/cstNone, /*pinMemory=*/cstNone,
+              /*memoryFormat=*/cstNone);
+
+          rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
+              binder.op, resultType, llvm::SmallVector<Value>{self});
+          return success();
+        });
+    patterns.onOp(
+        "SequenceErase", 11,
+        [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+          Torch::ListType resultType;
+          Value inputSequence, position;
+          if (binder.tensorListOperandAtIndex(inputSequence, 0) ||
+              binder.tensorListResultType(resultType))
+            return failure();
+
+          Value length = rewriter.create<Torch::AtenLenTOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(), inputSequence);
+
+          Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+          Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(1));
+          if (binder.op->getNumOperands() == 1) {
+            // If True, it means that the `position` arg is missing and
+            // the last tensor from the list has to be erased.
+            Value lengthMinusOne = rewriter.create<Torch::AtenSubIntOp>(
+                binder.getLoc(), length, cstOne);
+            rewriter.replaceOpWithNewOp<Torch::AtenSliceTOp>(
+                binder.op, resultType, inputSequence, /*start=*/cstNone,
+                /*end=*/lengthMinusOne, /*step=*/cstOne);
+            return success();
+          }
+
+          if (binder.tensorOperandAtIndex(position, 1))
+            return failure();
+
+          Value positionInt = rewriter.create<Torch::AtenItemOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(), position);
+          // Handling negative position value.
+          Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(0));
+          Value isPositionNegative = rewriter.create<Torch::AtenLtIntOp>(
+              binder.getLoc(), positionInt, cstZero);
+          isPositionNegative = rewriter.create<Torch::AtenIntBoolOp>(
+              binder.getLoc(), isPositionNegative);
+          Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
+              binder.getLoc(), isPositionNegative, length);
+          positionInt = rewriter.create<Torch::AtenAddIntOp>(
+              binder.getLoc(), positionInt, finalOffset);
+
+          Value listBeforePosition = rewriter.create<Torch::AtenSliceTOp>(
+              binder.getLoc(), resultType, inputSequence, /*start=*/cstNone,
+              /*end=*/positionInt, /*step=*/cstOne);
+          Value positionPlusOne = rewriter.create<Torch::AtenAddIntOp>(
+              binder.getLoc(), positionInt, cstOne);
+          Value listAfterPosition = rewriter.create<Torch::AtenSliceTOp>(
+              binder.getLoc(), resultType, inputSequence,
+              /*start=*/positionPlusOne,
+              /*end=*/length, /*step=*/cstOne);
+
+          rewriter.replaceOpWithNewOp<Torch::AtenAddTOp>(
+              binder.op, resultType, listBeforePosition, listAfterPosition);
+          return success();
+        });
+    patterns.onOp(
+        "SequenceInsert", 11,
+        [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+          Torch::ListType resultType;
+          Value inputSequence, position, insertValue;
+          if (binder.tensorListOperandAtIndex(inputSequence, 0) ||
+              binder.tensorOperandAtIndex(insertValue, 1) ||
+              binder.tensorListResultType(resultType))
+            return failure();
+
+          if (binder.op->getNumOperands() == 2) {
+            // If True, it means that the `position` arg is missing and
+            // the tensor has to be inserted at the end of the list.
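+            // (Inserting at index == len(list) appends to the end, so the
+            // list length computed below is used as the insert position.)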
+ Value length = rewriter.create( + binder.getLoc(), rewriter.getType(), + inputSequence); + rewriter.replaceOpWithNewOp( + binder.op, inputSequence, /*idx=*/length, + /*el=*/insertValue); + return success(); + } + + if (binder.tensorOperandAtIndex(position, 2)) + return failure(); + Value positionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), position); + rewriter.create(binder.getLoc(), inputSequence, + /*idx=*/positionInt, + /*el=*/insertValue); + rewriter.replaceOp(binder.op, inputSequence); return success(); }); } diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp index e7baf2e243fc..bec6ade4270c 100644 --- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -28,7 +28,8 @@ Value mlir::torch::onnx_c::createConstantIntList( cstValue); } -Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { +Torch::ValueTensorType +mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { Torch::ValueTensorType tty = dyn_cast(ty); if (!tty) return nullptr; @@ -40,6 +41,8 @@ Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { dty = Torch::QUInt8Type::get(ctx); if (dty.isSignedInteger(8)) dty = Torch::QInt8Type::get(ctx); + if (dty.isSignedInteger(16)) + dty = Torch::QInt16Type::get(ctx); if (dty.isSignedInteger(32)) dty = Torch::QInt32Type::get(ctx); diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index b9b0fb0ae5d7..dc8b5d431002 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -2579,6 +2579,8 @@ class ConvertSparseOperatorOp : public OpConversionPattern { SmallVector ConvertSparseOperatorOp::legalizedNames = { "torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc", "torch.aten._to_bsr", "torch.aten._to_bsc", + "torch.aten.to_dense", "torch.aten.to_sparse", "torch.aten.to_csr", + "torch.aten.to_csc", "torch.aten.to_bsr", "torch.aten.to_bsc", }; } // namespace diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 092f7f90059e..c2e89e078eca 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -146,12 +146,11 @@ class ConvertAtenMmOp : public OpConversionPattern { "mismatching contracting dimension for torch.aten.mm")); } - auto resultTy = cast(op.getType()); - auto resultDTy = resultTy.toBuiltinTensor().getElementType(); - Type newResultType = getTypeConverter()->convertType(op.getType()); - Type elementType = cast(newResultType).getElementType(); - auto accumulatorDType = getDefaultAccType(rewriter, resultDTy); - if (accumulatorDType != resultDTy) { + TensorType resultType = + cast(getTypeConverter()->convertType(op.getType())); + Type elementType = resultType.getElementType(); + auto accumulatorDType = getDefaultAccType(rewriter, elementType); + if (accumulatorDType != resultType.getElementType()) { elementType = accumulatorDType; } Value zeroFill = createZeroInitTensor( @@ -197,18 +196,16 @@ class ConvertAtenMmOp : public OpConversionPattern { .getResult(0); } - if (accumulatorDType != resultDTy) { - Type resultElementType = - cast(newResultType).getElementType(); + if (accumulatorDType != resultType.getElementType()) { matmul = torch_to_linalg::convertTensorToElementType( - rewriter, loc, matmul, resultElementType); + rewriter, loc, matmul, resultType.getElementType()); } // When constructed with just dynamic 
sizes, EmptyOp will have a result // type which has all `?`'s for dimensions, which might not be the result // type of `op`. The constraints on later linalg ops means that the result // of the MatmulOp will have this type too. So cast it to the desired type // so that in the end we have the original result type. - rewriter.replaceOpWithNewOp(op, newResultType, matmul); + rewriter.replaceOpWithNewOp(op, resultType, matmul); return success(); } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 0aa919fe04a6..46b51558f13d 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -559,6 +559,8 @@ bool torch_to_linalg::isUnsignedTorchType(Type type) { return false; if (isa(type)) return true; + if (isa(type)) + return false; if (isa(type)) return false; if (auto intTy = dyn_cast(type)) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 850703a80d88..4b01d88223b7 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1311,7 +1311,7 @@ static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, return nullptr; auto dty = resultTy.getDtype(); - auto resultBTy = resultTy.toBuiltinTensor().clone(dty); + auto resultBTy = resultTy.toBuiltinTensor(); auto fpTy = dyn_cast(dty); auto intTy = dyn_cast(dty); @@ -1521,7 +1521,7 @@ OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) { if (!ty || !ty.hasDtype() || !ty.hasSizes()) return nullptr; - auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + auto bty = ty.toBuiltinTensor(); if (!bty.hasStaticShape()) return nullptr; @@ -1635,7 +1635,6 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, return nullptr; auto ctx = lhs.getContext(); - auto resultETy = resultTy.getDtype(); auto tensorETy = cast(lhs.getType()).getElementType(); if (lhs.isSplat()) { if (auto intAttr = dyn_cast(rhs)) { @@ -1647,8 +1646,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, unsign ? 
tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); auto resultBool = intFolder(tensorAP, scalarAP, unsign); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - resultAP); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP); } if (auto floatAttr = dyn_cast(rhs)) { @@ -1657,8 +1655,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, auto resultBool = fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - resultAP); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP); } return nullptr; } @@ -1681,8 +1678,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, auto resultBool = intFolder(tensorAP, scalarAP, unsign); values.push_back(resultBool); } - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - values); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values); } if (auto floatAttr = dyn_cast(rhs)) { @@ -1693,8 +1689,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); values.push_back(resultBool); } - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - values); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values); } return nullptr; @@ -1844,7 +1839,7 @@ static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand, if (!fpTy && !intTy) return nullptr; - auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype()); + auto resultBTy = resultTy.toBuiltinTensor(); bool splat = operand.isSplat(); bool withinMaxFold = resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold; @@ -2192,7 +2187,7 @@ OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) { return nullptr; auto selfTy = cast(self.getType()); - auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + auto bty = ty.toBuiltinTensor(); if (!bty.hasStaticShape()) return nullptr; @@ -2671,8 +2666,7 @@ LogicalResult AtenSortOp::fold(FoldAdaptor adaptor, if (!indicesTensorType.hasDtype()) return failure(); - auto indicesType = - indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype()); + auto indicesType = indicesTensorType.toBuiltinTensor(); if (!indicesType || !indicesType.hasStaticShape()) return failure(); @@ -3627,9 +3621,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { return nullptr; if (input && input.isSplat()) - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), - input.getSplatValue()); + return DenseElementsAttr::get(outType.toBuiltinTensor(), + input.getSplatValue()); int count = 1; for (auto dim : outType.getSizes()) @@ -3667,8 +3660,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { for (int i = begin; i < limit; i += stride) values.push_back(input.getValues()[i]); - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), values); + return DenseElementsAttr::get(outType.toBuiltinTensor(), values); } // If the input and output shapes are the same we can just fold: @@ -4007,7 +3999,7 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType 
shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); SmallVector data; if (matchPattern(getData(), m_TorchListOfConstantInts(data)) && @@ -4028,7 +4020,7 @@ OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); int64_t data; if (matchPattern(getT(), m_TorchConstantInt(&data))) { @@ -4048,7 +4040,7 @@ OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); double data; if (matchPattern(getT(), m_TorchConstantFloat(&data))) { @@ -4221,7 +4213,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) { : selfAttr.getValues()[indexInt]; auto dty = resultTy.getDtype(); - auto attrTy = resultTy.toBuiltinTensor().clone(dty); + auto attrTy = resultTy.toBuiltinTensor(); if (auto floatAttr = dyn_cast(splattr)) return DenseElementsAttr::get( attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble())); @@ -4414,7 +4406,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!valueDense.isSplat()) return nullptr; auto splattr = valueDense.getSplatValue(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, splattr); } @@ -4422,7 +4414,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!isa(dty)) return nullptr; int64_t intval = intAttr.getInt(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval)); } @@ -4430,7 +4422,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!isa(dty)) return nullptr; double dblval = fpAttr.getValueAsDouble(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval)); } diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 6735bb37e48b..c46865ee5fed 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -185,7 +185,8 @@ static bool isValidTorchDtype(Type dtype) { dtype = cast(dtype).getElementType(); } // Torch quantized types. - if (isa(dtype)) + if (isa(dtype)) return true; // Builtin floating point types. 
 if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))
@@ -453,12 +454,7 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
 }
 
 static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
-  if (auto floatType = dyn_cast<mlir::FloatType>(dtype)) {
-    return dtype;
-  } else if (auto integerType = dyn_cast<IntegerType>(dtype)) {
-    return IntegerType::get(context, integerType.getWidth(),
-                            IntegerType::Signless);
-  } else if (isa<mlir::ComplexType>(dtype)) {
+  if (isa<mlir::FloatType, mlir::IntegerType, mlir::ComplexType>(dtype)) {
     return dtype;
   }
@@ -468,6 +464,9 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
   if (isa<QInt8Type>(dtype))
     return IntegerType::get(context, 8, IntegerType::Signless);
 
+  if (isa<QInt16Type>(dtype))
+    return IntegerType::get(context, 16, IntegerType::Signless);
+
   if (isa<QInt32Type>(dtype))
     return IntegerType::get(context, 32, IntegerType::Signless);
 
@@ -480,11 +479,11 @@ TensorType ValueTensorType::toBuiltinTensor() const {
   if (!hasDtype())
     return nullptr;
-  if (!hasSizes())
-    return UnrankedTensorType::get(getDtype());
   Type elementType =
       convertDtypeToBuiltinElementType(getContext(), getDtype());
   if (!elementType)
     return nullptr;
+  if (!hasSizes())
+    return UnrankedTensorType::get(elementType);
   return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()),
                                elementType, getOptionalSparsity());
 }
diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp
index c237ede12479..b571003940cb 100644
--- a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp
@@ -21,10 +21,12 @@ using namespace mlir::torch::Torch;
 namespace {
 
 Type getQuantizedType(MLIRContext *context, Type t) {
-  if (t.isSignlessInteger(8))
+  if (t.isSignlessInteger(8) || t.isUnsignedInteger(8))
     return Torch::QUInt8Type::get(context);
   if (t.isInteger(8) || t.isSignedInteger(8))
     return Torch::QInt8Type::get(context);
+  if (t.isInteger(16))
+    return Torch::QInt16Type::get(context);
   if (t.isInteger(32))
     return Torch::QInt32Type::get(context);
   return {};
diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
index 2dce14ef964c..c4c42f7fe0e3 100644
--- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp
+++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
@@ -21,7 +21,7 @@ static inline bool isQIntType(ScalarType t) {
   // Don't forget to extend this when adding new QInt types
   return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
          t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
-         t == ScalarType::QUInt2x4;
+         t == ScalarType::QUInt2x4 || t == ScalarType::QInt16;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 511ff770beaa..7ba3157b8986 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -112,6 +112,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
     return torch_upstream::ScalarType::QUInt8;
   if (isa<QInt8Type>(type))
     return torch_upstream::ScalarType::QInt8;
+  if (isa<QInt16Type>(type))
+    return torch_upstream::ScalarType::QInt16;
   if (isa<QInt32Type>(type))
     return torch_upstream::ScalarType::QInt32;
   if (isa<mlir::ComplexType>(type)) {
@@ -123,6 +125,14 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
     mlir::Type complexElemType = cast<mlir::ComplexType>(type).getElementType();
     if (complexElemType.isF16())
       return torch_upstream::ScalarType::ComplexHalf;
    if (complexElemType.isF32())
      return torch_upstream::ScalarType::ComplexFloat;
     if (complexElemType.isF64())
       return torch_upstream::ScalarType::ComplexDouble;
   }
+  if (isa<Float8E5M2Type>(type))
+    return torch_upstream::ScalarType::Float8_e5m2;
+  if (isa<Float8E4M3FNType>(type))
+    return torch_upstream::ScalarType::Float8_e4m3fn;
+  if (isa<Float8E5M2FNUZType>(type))
+ return torch_upstream::ScalarType::Float8_e5m2fnuz; + if (isa(type)) + return torch_upstream::ScalarType::Float8_e4m3fnuz; llvm::report_fatal_error("unhandled type for getScalarTypeForType"); } Type Torch::getTypeForTorchType( @@ -163,6 +173,8 @@ Torch::getTypeForScalarType(MLIRContext *context, return QUInt8Type::get(context); case torch_upstream::ScalarType::QInt8: return QInt8Type::get(context); + case torch_upstream::ScalarType::QInt16: + return QInt16Type::get(context); case torch_upstream::ScalarType::QInt32: return QInt32Type::get(context); case torch_upstream::ScalarType::ComplexHalf: @@ -171,6 +183,14 @@ Torch::getTypeForScalarType(MLIRContext *context, return mlir::ComplexType::get(Float32Type::get(context)); case torch_upstream::ScalarType::ComplexDouble: return mlir::ComplexType::get(Float64Type::get(context)); + case torch_upstream::ScalarType::Float8_e5m2: + return Float8E5M2Type::get(context); + case torch_upstream::ScalarType::Float8_e4m3fn: + return Float8E4M3FNType::get(context); + case torch_upstream::ScalarType::Float8_e5m2fnuz: + return Float8E5M2FNUZType::get(context); + case torch_upstream::ScalarType::Float8_e4m3fnuz: + return Float8E4M3FNUZType::get(context); case torch_upstream::ScalarType::Undefined: return failure(); default: diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index 70554680ef16..0f2533e063f0 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -164,7 +164,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion( ConversionTarget &target, TypeConverter &typeConverter) { auto valueTensorTypeConversion = [](Torch::ValueTensorType type) -> std::optional { - return type.toBuiltinTensor(); + auto builtinType = type.toBuiltinTensor(); + if (!builtinType) + return std::nullopt; + + // convert any integer type to signless + if (type.getDtype().isInteger()) { + return builtinType.clone(IntegerType::get( + builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(), + IntegerType::Signless)); + } + + return builtinType; }; setupValueTensorToBuiltinTensorConversion(target, typeConverter, valueTensorTypeConversion); @@ -180,9 +191,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo( auto valueTensorTypeConversion = [](Torch::ValueTensorType type) -> std::optional { auto builtinType = type.toBuiltinTensor(); + if (!builtinType) + return std::nullopt; + + // convert signed integer type to signless, keep unsigned as unsigned if (type.getDtype().isUnsignedInteger()) { return builtinType.clone(type.getDtype()); + } else if (type.getDtype().isSignedInteger()) { + return builtinType.clone(IntegerType::get( + builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(), + IntegerType::Signless)); } + return builtinType; }; setupValueTensorToBuiltinTensorConversion(target, typeConverter, diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 9dcb3c285dc8..2a73325c7d76 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -99,6 +99,10 @@ FloatAttr, BF16Type, ComplexType, + Float8E5M2Type, + Float8E4M3FNType, + Float8E5M2FNUZType, + Float8E4M3FNUZType, F16Type, F32Type, F64Type, @@ -147,6 +151,10 @@ torch.complex32: "complex", torch.complex64: "complex", torch.complex128: "complex", + torch.float8_e5m2: "f8E5M2", + torch.float8_e4m3fn: 
"f8E4M3FN", + torch.float8_e5m2fnuz: "f8E5M2FNUZ", + torch.float8_e4m3fnuz: "f8E4M3FNUZ", } TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = { @@ -165,6 +173,10 @@ torch.complex32: lambda: ComplexType.get(F16Type.get()), torch.complex64: lambda: ComplexType.get(F32Type.get()), torch.complex128: lambda: ComplexType.get(F64Type.get()), + torch.float8_e5m2: lambda: Float8E5M2Type.get(), + torch.float8_e5m2fnuz: lambda: Float8E5M2FNUZType.get(), + torch.float8_e4m3fn: lambda: Float8E4M3FNType.get(), + torch.float8_e4m3fnuz: lambda: Float8E4M3FNUZType.get(), } TORCH_DTYPE_TO_NPY_TYPE = { @@ -203,6 +215,10 @@ # torch.quint8: 13, # torch.qint32 14 torch.bfloat16: 15, + torch.float8_e5m2: 23, + torch.float8_e4m3fn: 24, + torch.float8_e5m2fnuz: 25, + torch.float8_e4m3fnuz: 26, } TORCH_MEMORY_FORMAT_TO_INT = { @@ -1429,7 +1445,8 @@ def import_nodes( operands = [self._import_argument(loc, arg) for arg in node.args[0]] func_dialect.ReturnOp(operands, loc=loc) - self._create_bind_symbolic_shape_ops(loc, node) + if import_symbolic_shape_expressions: + self._create_bind_symbolic_shape_ops(loc, node) def _promote_symbolic_scalar_int_float(self, loc, graph, param): temp_target = torch.ops.aten.Float.Scalar diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index e0d3529d942e..4c1e0b9e9aed 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -34,6 +34,7 @@ ) from e from typing import Optional, List, Dict, Tuple +import warnings from dataclasses import dataclass @@ -579,6 +580,10 @@ def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType: def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: if tp == "": + warnings.warn( + "Found a node without a valid type proto. Consider updating the opset_version of" + " the model and/or running the importer with the flag '--clear-domain'." + ) return self.get_none_type() tt = tp.tensor_type diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index 92ae3c7eb356..bca87cee7f59 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -20,6 +20,7 @@ import sys import onnx +import onnx.version from ...extras import onnx_importer @@ -81,6 +82,16 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: raw_model = onnx.load(args.input_file, load_external_data=False) onnx.load_external_data_for_model(raw_model, args.data_dir) + if args.opset_version: + raw_model = onnx.version_converter.convert_version( + raw_model, args.opset_version + ) + + if args.clear_domain: + graph = raw_model.graph + for n in graph.node: + n.ClearField("domain") + # Run the checker to test whether the file is above the threshold for # in-memory shape inference. If not, go ahead and do the shape inference. 
    try:
@@ -149,6 +160,14 @@ def parse_arguments(argv=None) -> argparse.Namespace:
         action=argparse.BooleanOptionalAction,
         help="Toggle data propagation for onnx shape inference",
     )
+    parser.add_argument(
+        "--clear-domain",
+        dest="clear_domain",
+        default=False,
+        action=argparse.BooleanOptionalAction,
+        help="If enabled, this will clear the domain attribute from each node"
+        " in the onnx graph before performing shape inference.",
+    )
     parser.add_argument(
         "--keep-temps", action="store_true", help="Keep intermediate files"
     )
@@ -170,6 +189,12 @@ def parse_arguments(argv=None) -> argparse.Namespace:
         " Defaults to the directory of the input file.",
         type=Path,
     )
+    parser.add_argument(
+        "--opset-version",
+        help="Allows specification of a newer opset_version to update the model"
+        " to before importing to MLIR. This can sometimes assist with shape inference.",
+        type=int,
+    )
     args = parser.parse_args(argv)
     return args
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index 3f437fc4c5c1..5b33fd17471b 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -748,6 +748,19 @@ func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !tor
 
 // -----
 
+// CHECK-LABEL: @test_dequantizelinear_si16
+func.func @test_dequantizelinear_si16(%arg0: !torch.vtensor<[6],si16>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si16>, !torch.vtensor<[],f32>, !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32>
+  // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int
+  // CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]]
+  // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]]
+  // CHECK: return %[[DEQ]]
+  return %0 : !torch.vtensor<[6],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_dequantizelinear_ui8
 func.func @test_dequantizelinear_ui8(%arg0: !torch.vtensor<[6],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
   %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32>
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 0d72f2252abb..79958da59c77 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -2217,6 +2217,19 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
 
 // -----
 
+// CHECK-LABEL: func.func @test_resize_sizes_nearest
+func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %none = torch.constant.none
+  // CHECK: %[[STR:.+]] = torch.constant.str
"nearest_half_pixel,round_prefer_floor" + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.coordinate_transformation_mode = "half_pixel", + torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_resize_sizes_linear func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { @@ -2373,3 +2386,192 @@ func.func @test_shrink_hard(%arg0: !torch.vtensor<[5],f32>) -> !torch.vtensor<[5 %0 = torch.operator "onnx.Shrink"(%arg0) {torch.onnx.lambd = 1.500000e+00 : f32} : (!torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> return %0 : !torch.vtensor<[5],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_sequence_at +func.func @test_sequence_at(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], 
%[[ITEM_0]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %3 = torch.operator "onnx.SequenceErase"(%2, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + %4 = torch.operator "onnx.SequenceAt"(%3, %1) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32> + return %4 : !torch.vtensor<[2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_insert +func.func @test_sequence_insert(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-3> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_2:.*]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.insert.t %[[CONCAT_LIST]], %[[ITEM_0]], %arg0 : !torch.list>, !torch.int, !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[VTENSOR_2]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_1]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32> + %0 = torch.operator 
"onnx.Constant"() {torch.onnx.value = dense<-3> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + %5 = torch.operator "onnx.SequenceInsert"(%4, %arg0, %1) : (!torch.list>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[],si64>) -> !torch.list> + %6 = torch.operator "onnx.SequenceAt"(%5, %2) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32> + return %6 : !torch.vtensor<[2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_at_beginning +func.func @test_sequence_erase_at_beginning(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_at_end +func.func @test_sequence_erase_at_end(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: 
!torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_negative_idx +func.func @test_sequence_erase_negative_idx(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-2> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: 
%[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-2> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_empty +func.func @test_sequence_erase_empty() -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE_0:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE_0]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() 
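The CHECK sequences above all exercise the same lowering recipe for onnx.SequenceErase: normalize a possibly-negative index against the list length, then rebuild the list from the two slices on either side of the erased position. A minimal Python sketch of that recipe (an illustration with names of our choosing, not compiler code):

    # Sketch of the SequenceErase decomposition checked above. Each step
    # mirrors a torch op in the CHECK lines: len.t for the length,
    # lt.int/Int.bool/mul.int/add.int for index normalization, then
    # slice.t twice and add.t for the concatenation.
    def sequence_erase(seq, idx):
        length = len(seq)                            # torch.aten.len.t
        offset = int(idx < 0) * length               # lt.int, Int.bool, mul.int
        position = idx + offset                      # torch.aten.add.int
        return seq[:position] + seq[position + 1:]   # slice.t x2, add.t

    assert sequence_erase(["a", "b", "c"], -2) == ["a", "c"]
    assert sequence_erase(["a", "b", "c"], 0) == ["b", "c"]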
diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir
index 8242321c3303..29ab52f9dab0 100644
--- a/test/Dialect/Torch/ops.mlir
+++ b/test/Dialect/Torch/ops.mlir
@@ -174,6 +174,7 @@ func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) {
 
 func.func private @tensor_legal_dtype$torch.qint8() -> !torch.tensor<*,!torch.qint8>
 func.func private @tensor_legal_dtype$torch.quint8() -> !torch.tensor<*,!torch.quint8>
+func.func private @tensor_legal_dtype$torch.qint16() -> !torch.tensor<*,!torch.qint16>
 
 func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,56,96],f16>, %arg1: !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor<[1,?,56,96],f16>> {
   %arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor<[1,?,56,96],f16>>
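For completeness, the new dtype can also be exercised from the Python side. The following is a hedged sketch, assuming an in-tree build of the torch_mlir Python package; it simply round-trips the type through the parser rather than using any API added by this patch:

    # Hedged sketch: parse the new !torch.qint16 dtype through the Python
    # bindings to confirm it round-trips (mirrors the ops.mlir test above).
    from torch_mlir import ir
    from torch_mlir.dialects import torch as torch_d

    with ir.Context() as ctx:
        torch_d.register_dialect(ctx)
        t = ir.Type.parse("!torch.tensor<*,!torch.qint16>")
        print(t)  # expected: !torch.tensor<*,!torch.qint16>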
diff --git a/test/python/fx_importer/custom_op_test.py b/test/python/fx_importer/custom_op_test.py
new file mode 100644
index 000000000000..dbbc5ba057af
--- /dev/null
+++ b/test/python/fx_importer/custom_op_test.py
@@ -0,0 +1,86 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import torch.nn as nn
+from torch.export import Dim
+from torch.library import Library, impl, impl_abstract
+
+from torch_mlir import fx
+
+
+def run(f):
+    print(f"{f.__name__}")
+    print("-" * len(f.__name__))
+    f()
+    print()
+
+
+@run
+# CHECK-LABEL: test_tanh_sigmoid_cat_custom_op
+# CHECK: func.func @main(
+# CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
+# CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
+# CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
+# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
+# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
+# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
+# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
+# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32>
+def test_tanh_sigmoid_cat_custom_op():
+
+    m = Library("my_custom_library", "DEF")
+    m.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor")
+
+    @impl(m, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd")
+    def custom_op(x, y, z):
+        a = torch.tanh(x)
+        b = torch.sigmoid(y)
+        return torch.cat((a, a, b, z), dim=1)
+
+    @impl_abstract("my_custom_library::tanh_sigmoid_cat_op")
+    def custom_op_meta(x, y, z):
+        result = custom_op(x, y, z)
+        return torch.empty_like(result)
+
+    class TanhSigmoidCatCustomOp(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y, z):
+            return torch.ops.my_custom_library.tanh_sigmoid_cat_op(x, y, z)
+
+    # Sample inputs
+    x = torch.randn(5, 2, 3)
+    y = torch.randn(5, 6, 3)
+    z = torch.randn(5, 4, 3)
+
+    # Dynamic dim constraints
+    dim_n = Dim("n", min=5, max=10)
+    dim_x1 = Dim("x1", max=100)
+    dim_y1 = Dim("y1", max=50)
+    dim_z1 = Dim("z1")
+    dynamic_shapes = {
+        "x": {0: dim_n, 1: dim_x1},
+        "y": {0: dim_n, 1: dim_y1},
+        "z": {0: dim_n, 1: dim_z1},
+    }
+
+    m = fx.export_and_import(
+        TanhSigmoidCatCustomOp(),
+        x,
+        y,
+        z,
+        dynamic_shapes=dynamic_shapes,
+        import_symbolic_shape_expressions=True,
+    )
+    print(m)
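The bound result shape in the CHECK lines, `(s0, s2 + s3 + s1 * 2, 3)`, is just the eager shape arithmetic of the op: `cat((a, a, b, z), dim=1)` contributes dim 1 of `x` twice plus the dim 1 extents of `y` and `z`. A quick eager-mode sanity check with the sample inputs (illustration only, not part of the test):

    # With x:(5,2,3), y:(5,6,3), z:(5,4,3), concatenating along dim 1
    # yields 2*2 + 6 + 4 = 14, matching s2 + s3 + s1 * 2 with the map
    # operands bound to s1=2 (x), s2=6 (y), s3=4 (z).
    import torch

    x, y, z = torch.randn(5, 2, 3), torch.randn(5, 6, 3), torch.randn(5, 4, 3)
    out = torch.cat((torch.tanh(x), torch.tanh(x), torch.sigmoid(y), z), dim=1)
    assert out.shape == (5, 14, 3)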
if opname in {"abs", "neg", "relu", "sin"}: node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) - elif opname == "_to_sparse": + elif opname == "_to_sparse" or opname == "to_sparse": dim = len(node.meta.get("val").shape) node.meta["sparsity"] = SparsityMeta( torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 @@ -339,6 +339,14 @@ def forward(self, x, v): @run # +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, +# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> { +# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32> +# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32> +# CHECK: } +## # CHECK: torch.sparse # CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], # CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], @@ -360,7 +368,7 @@ def forward(self, x, y): dense_input = torch.ones(8, 8) sparse_input = dense_input.to_sparse_coo() m = export_and_import(net, sparse_input, dense_input) - # print(m) + print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. res1 = net(sparse_input, dense_input) @@ -500,12 +508,29 @@ def forward(self, x): @run # +# CHECK-LABEL: test_sparse_activation +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> { +# CHECK: %[[N1:.*]] = torch.constant.none +# CHECK: %[[N2:.*]] = torch.constant.none +# CHECK: %[[N3:.*]] = torch.constant.none +# CHECK: %[[R:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> +# CHECK: return %[[R]] : !torch.vtensor<[2,2,2],f32,#[[$COO]]> +# CHECK: } +# # CHECK: torch.sparse # CHECK: tensor(indices=tensor({{\[}}[0, 0, 0, 0, 1, 1, 1, 1], # CHECK: [0, 0, 1, 1, 0, 0, 1, 1], # CHECK: [0, 1, 0, 1, 0, 1, 0, 1]{{\]}}), # CHECK: values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]), # CHECK: size=(2, 2, 2), nnz=8, layout=torch.sparse_coo) +# CHECK: torch.mlir +# CHECK: [0 8] +# CHECK: [0 0 0 0 1 1 1 1] +# CHECK: [0 0 1 1 0 0 1 1] +# CHECK: [0 1 0 1 0 1 0 1] +# CHECK: [1. 1. 1. 1. 1. 1. 1. 1.] # def test_sparse_activation(): class SparseActivationCOO(torch.nn.Module): @@ -515,19 +540,19 @@ def forward(self, x): net = SparseActivationCOO() x = torch.ones(2, 2, 2) m = export_and_import(net, x) - # print(m) + print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. res1 = net(x) - # res2 = sparse_jit(net, x) + res2 = sparse_jit(net, x) print("torch.sparse") print(res1) - # print("torch.mlir") - # print(res2[0]) - # print(res2[1]) - # print(res2[2]) - # print(res2[3]) - # print(res2[4]) + print("torch.mlir") + print(res2[0]) + print(res2[1]) + print(res2[2]) + print(res2[3]) + print(res2[4]) @run @@ -542,6 +567,8 @@ def forward(self, x): # # CHECK: torch.sparse # CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.]) +# CHECK: torch.mlir +# CHECK: [ 0. 11. 9. 11. 13. 11. 10. 12.] 
# def test_sparse_network(): def spike(input): @@ -607,15 +634,24 @@ def forward(self, X): # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. res1 = net(x) - # res2 = sparse_jit(net, x) + res2 = sparse_jit(net, x) print("torch.sparse") print(res1) - # print("torch.mlir") - # print(res2) + print("torch.mlir") + print(res2) @run # +# CHECK-LABEL: test_sparse_feature_scaling +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { +# ... more IR ... +# CHECK: %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}" +# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]] +# CHECK return %[[R]] : !torch.vtensor<[4,4],f32> +# CHECK: } +# # CHECK: torch.sparse # CHECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889], # CHECK: [0.1321, 0.2724, 0.2105, 0.3851], @@ -638,7 +674,7 @@ def forward(self, F): torch.manual_seed(0) f = torch.rand(4, 4) m = export_and_import(net, f) - # print(m) + print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. res1 = net(f)
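The re-enabled `torch.mlir` CHECK blocks compare against the raw COO buffers (positions, one coordinate array per dimension, values) that `sparse_jit` hands back. The expected contents can be reproduced in eager PyTorch; a small illustration, not part of the test:

    # Eager illustration of the buffers checked in test_sparse_activation:
    # an all-ones 2x2x2 COO tensor enumerates all 8 coordinates with unit
    # values; the positions buffer [0 8] brackets them as a single segment.
    import torch

    a = torch.ones(2, 2, 2).to_sparse()
    print(a.indices())  # rows: [0 0 0 0 1 1 1 1], [0 0 1 1 0 0 1 1], [0 1 0 1 0 1 0 1]
    print(a.values())   # tensor([1., 1., 1., 1., 1., 1., 1., 1.])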