From 7f188eb824774753ebd169786aa4cc45e99b977c Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 7 Jun 2024 13:58:18 -0700 Subject: [PATCH 01/12] Add f8 types to fx importer (#3434) Missing types for tracing float8 types. --- python/torch_mlir/extras/fx_importer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 9dcb3c285dc8..16c27c0fa318 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -99,6 +99,10 @@ FloatAttr, BF16Type, ComplexType, + Float8E5M2Type, + Float8E4M3FNType, + Float8E5M2FNUZType, + Float8E4M3FNUZType, F16Type, F32Type, F64Type, @@ -147,6 +151,10 @@ torch.complex32: "complex", torch.complex64: "complex", torch.complex128: "complex", + torch.float8_e5m2: "f8E5M2", + torch.float8_e4m3fn: "f8E4M3FN", + torch.float8_e5m2fnuz: "f8E5M2FNUZ", + torch.float8_e4m3fnuz: "f8E4M3FNUZ", } TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = { @@ -165,6 +173,10 @@ torch.complex32: lambda: ComplexType.get(F16Type.get()), torch.complex64: lambda: ComplexType.get(F32Type.get()), torch.complex128: lambda: ComplexType.get(F64Type.get()), + torch.float8_e5m2: lambda: Float8E5M2Type.get(), + torch.float8_e5m2fnuz: lambda: Float8E5M2FNUZType.get(), + torch.float8_e4m3fn: lambda: Float8E4M3FNType.get(), + torch.float8_e4m3fnuz: lambda: Float8E4M3FNUZType.get(), } TORCH_DTYPE_TO_NPY_TYPE = { @@ -203,6 +215,10 @@ # torch.quint8: 13, # torch.qint32 14 torch.bfloat16: 15, + torch.float8_e5m2: 23, + torch.float8_e4m3fn: 24, + torch.float8_e5m2fnuz: 25, + torch.float8_e4m3fnuz: 26, } TORCH_MEMORY_FORMAT_TO_INT = { From 75af64fc121f8da79fdcdb308d3cfc5ebccbe10c Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 7 Jun 2024 13:59:38 -0700 Subject: [PATCH 02/12] [torch] Add support for f8 types for linalg conversion (#3436) Linalg conversion requires mapping for f8 types --- .../Dialect/Torch/Utils/TorchUpstream.h | 45 +++++++++++-------- lib/Dialect/Torch/Utils/Utils.cpp | 16 +++++++ 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h index 043dd92549b2..3d2c8bb588d7 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h @@ -86,24 +86,33 @@ enum class TypeKind { // at:: and c10:: parts of the macro are never used within the compiler -- we // only use this for the enum values. 
 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
-  _(uint8_t, Byte) /* 0 */ \
-  _(int8_t, Char) /* 1 */ \
-  _(int16_t, Short) /* 2 */ \
-  _(int, Int) /* 3 */ \
-  _(int64_t, Long) /* 4 */ \
-  _(at::Half, Half) /* 5 */ \
-  _(float, Float) /* 6 */ \
-  _(double, Double) /* 7 */ \
-  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
-  _(c10::complex<float>, ComplexFloat) /* 9 */ \
-  _(c10::complex<double>, ComplexDouble) /* 10 */ \
-  _(bool, Bool) /* 11 */ \
-  _(c10::qint8, QInt8) /* 12 */ \
-  _(c10::quint8, QUInt8) /* 13 */ \
-  _(c10::qint32, QInt32) /* 14 */ \
-  _(at::BFloat16, BFloat16) /* 15 */ \
-  _(c10::quint4x2, QUInt4x2) /* 16 */ \
-  _(c10::quint2x4, QUInt2x4) /* 17 */
+  _(uint8_t, Byte) /* 0 */                                  \
+  _(int8_t, Char) /* 1 */                                   \
+  _(int16_t, Short) /* 2 */                                 \
+  _(int, Int) /* 3 */                                       \
+  _(int64_t, Long) /* 4 */                                  \
+  _(at::Half, Half) /* 5 */                                 \
+  _(float, Float) /* 6 */                                   \
+  _(double, Double) /* 7 */                                 \
+  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */           \
+  _(c10::complex<float>, ComplexFloat) /* 9 */              \
+  _(c10::complex<double>, ComplexDouble) /* 10 */           \
+  _(bool, Bool) /* 11 */                                    \
+  _(c10::qint8, QInt8) /* 12 */                             \
+  _(c10::quint8, QUInt8) /* 13 */                           \
+  _(c10::qint32, QInt32) /* 14 */                           \
+  _(at::BFloat16, BFloat16) /* 15 */                        \
+  _(c10::quint4x2, QUInt4x2) /* 16 */                       \
+  _(c10::quint2x4, QUInt2x4) /* 17 */                       \
+  _(c10::bits1x8, Bits1x8) /* 18 */                         \
+  _(c10::bits2x4, Bits2x4) /* 19 */                         \
+  _(c10::bits4x2, Bits4x2) /* 20 */                         \
+  _(c10::bits8, Bits8) /* 21 */                             \
+  _(c10::bits16, Bits16) /* 22 */                           \
+  _(c10::Float8_e5m2, Float8_e5m2) /* 23 */                 \
+  _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */             \
+  _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */         \
+  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */
 
 enum class ScalarType : int8_t {
 #define DEFINE_ENUM(_1, n) n,
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 1c7e6f284f29..388c38b25cb3 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -80,6 +80,14 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
     if (complexElemType.isF64())
       return torch_upstream::ScalarType::ComplexDouble;
   }
+  if (isa<Float8E5M2Type>(type))
+    return torch_upstream::ScalarType::Float8_e5m2;
+  if (isa<Float8E4M3FNType>(type))
+    return torch_upstream::ScalarType::Float8_e4m3fn;
+  if (isa<Float8E5M2FNUZType>(type))
+    return torch_upstream::ScalarType::Float8_e5m2fnuz;
+  if (isa<Float8E4M3FNUZType>(type))
+    return torch_upstream::ScalarType::Float8_e4m3fnuz;
   llvm::report_fatal_error("unhandled type for getScalarTypeForType");
 }
 Type Torch::getTypeForTorchType(
@@ -128,6 +136,14 @@ Torch::getTypeForScalarType(MLIRContext *context,
     return mlir::ComplexType::get(Float32Type::get(context));
   case torch_upstream::ScalarType::ComplexDouble:
     return mlir::ComplexType::get(Float64Type::get(context));
+  case torch_upstream::ScalarType::Float8_e5m2:
+    return Float8E5M2Type::get(context);
+  case torch_upstream::ScalarType::Float8_e4m3fn:
+    return Float8E4M3FNType::get(context);
+  case torch_upstream::ScalarType::Float8_e5m2fnuz:
+    return Float8E5M2FNUZType::get(context);
+  case torch_upstream::ScalarType::Float8_e4m3fnuz:
+    return Float8E4M3FNUZType::get(context);
   case torch_upstream::ScalarType::Undefined:
     return failure();
   default:

From 689efc89175cc339ca6a1df88be7d24172906c32 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Sat, 8 Jun 2024 09:36:32 +0800
Subject: [PATCH 03/12] [Torch] fix toBuiltinTensor() (#3415)

* Let `toBuiltinTensor()` reflect the original dtype of `!torch.vtensor`.
* Backends handle dtype conversion themselves.
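
As an illustrative sketch (not part of the patch itself): under the new
contract a value tensor such as

  !torch.vtensor<[2,3],si32>

is converted by `toBuiltinTensor()` to `tensor<2x3xsi32>`, preserving the
signed dtype; producing the signless `tensor<2x3xi32>` expected by the
linalg/stablehlo backends is now the job of the backend type conversions
updated below.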
--- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 18 ++++---- lib/Conversion/TorchToLinalg/Linear.cpp | 19 ++++---- lib/Dialect/Torch/IR/TorchOps.cpp | 46 ++++++++----------- lib/Dialect/Torch/IR/TorchTypes.cpp | 11 ++--- .../Transforms/BackendTypeConversion.cpp | 22 ++++++++- 5 files changed, 60 insertions(+), 56 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index b26e1ea3a5f1..b6cc7cdd0ac9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -737,7 +737,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( std::numeric_limits::lowest())) return failure(); auto minSplatAttr = SplatElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDtype), + resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, minValue)); min = rewriter.create( binder.getLoc(), resultType, minSplatAttr); @@ -748,7 +748,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( std::numeric_limits::max())) return failure(); auto maxSplatAttr = SplatElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDtype), + resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, maxValue)); max = rewriter.create( binder.getLoc(), resultType, maxSplatAttr); @@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (binder.op->hasAttr("torch.onnx.value_float") && !binder.f32FloatAttr(floatValue, "value_float", 0.0)) { auto splatAttr = - SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + SplatElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getFloatAttr(dtype, floatValue)); rewriter.replaceOpWithNewOp( binder.op, resultType, splatAttr); @@ -872,7 +872,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (binder.op->hasAttr("torch.onnx.value_int") && !binder.s64IntegerAttr(intValue, "value_int", 0)) { auto splatAttr = - SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + SplatElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getIntegerAttr(dtype, intValue)); rewriter.replaceOpWithNewOp( binder.op, resultType, splatAttr); @@ -932,8 +932,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( for (auto intVal : intValues) { apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal)); } - auto attr = DenseElementsAttr::get( - resultType.toBuiltinTensor().clone(dtype), apValues); + auto attr = + DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues); rewriter.replaceOpWithNewOp( binder.op, resultType, attr); return success(); @@ -2272,9 +2272,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // Extract the fill value and dtype // ONNX requires value attr to be a tensor if (!attr) { - attr = DenseElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDType), - rewriter.getFloatAttr(resultDType, 0.0)); + attr = + DenseElementsAttr::get(resultType.toBuiltinTensor(), + rewriter.getFloatAttr(resultDType, 0.0)); } // If its a dense resource attr we need to convert to a dense type: diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index aa560402877f..318c2bec361f 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -146,12 +146,11 @@ class ConvertAtenMmOp : public OpConversionPattern { "mismatching contracting dimension for torch.aten.mm")); } - auto resultTy = cast(op.getType()); - auto resultDTy = resultTy.toBuiltinTensor().getElementType(); - Type 
newResultType = getTypeConverter()->convertType(op.getType()); - Type elementType = cast(newResultType).getElementType(); - auto accumulatorDType = getDefaultAccType(rewriter, resultDTy); - if (accumulatorDType != resultDTy) { + TensorType resultType = + cast(getTypeConverter()->convertType(op.getType())); + Type elementType = resultType.getElementType(); + auto accumulatorDType = getDefaultAccType(rewriter, elementType); + if (accumulatorDType != resultType.getElementType()) { elementType = accumulatorDType; } Value zeroFill = createZeroInitTensor( @@ -197,18 +196,16 @@ class ConvertAtenMmOp : public OpConversionPattern { .getResult(0); } - if (accumulatorDType != resultDTy) { - Type resultElementType = - cast(newResultType).getElementType(); + if (accumulatorDType != resultType.getElementType()) { matmul = torch_to_linalg::convertTensorToElementType( - rewriter, loc, matmul, resultElementType); + rewriter, loc, matmul, resultType.getElementType()); } // When constructed with just dynamic sizes, EmptyOp will have a result // type which has all `?`'s for dimensions, which might not be the result // type of `op`. The constraints on later linalg ops means that the result // of the MatmulOp will have this type too. So cast it to the desired type // so that in the end we have the original result type. - rewriter.replaceOpWithNewOp(op, newResultType, matmul); + rewriter.replaceOpWithNewOp(op, resultType, matmul); return success(); } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 994722f3ea6f..61a0857a8894 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1311,7 +1311,7 @@ static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, return nullptr; auto dty = resultTy.getDtype(); - auto resultBTy = resultTy.toBuiltinTensor().clone(dty); + auto resultBTy = resultTy.toBuiltinTensor(); auto fpTy = dyn_cast(dty); auto intTy = dyn_cast(dty); @@ -1521,7 +1521,7 @@ OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) { if (!ty || !ty.hasDtype() || !ty.hasSizes()) return nullptr; - auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + auto bty = ty.toBuiltinTensor(); if (!bty.hasStaticShape()) return nullptr; @@ -1635,7 +1635,6 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, return nullptr; auto ctx = lhs.getContext(); - auto resultETy = resultTy.getDtype(); auto tensorETy = cast(lhs.getType()).getElementType(); if (lhs.isSplat()) { if (auto intAttr = dyn_cast(rhs)) { @@ -1647,8 +1646,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, unsign ? 
tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); auto resultBool = intFolder(tensorAP, scalarAP, unsign); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - resultAP); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP); } if (auto floatAttr = dyn_cast(rhs)) { @@ -1657,8 +1655,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, auto resultBool = fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - resultAP); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP); } return nullptr; } @@ -1681,8 +1678,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, auto resultBool = intFolder(tensorAP, scalarAP, unsign); values.push_back(resultBool); } - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - values); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values); } if (auto floatAttr = dyn_cast(rhs)) { @@ -1693,8 +1689,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); values.push_back(resultBool); } - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - values); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values); } return nullptr; @@ -1844,7 +1839,7 @@ static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand, if (!fpTy && !intTy) return nullptr; - auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype()); + auto resultBTy = resultTy.toBuiltinTensor(); bool splat = operand.isSplat(); bool withinMaxFold = resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold; @@ -2192,7 +2187,7 @@ OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) { return nullptr; auto selfTy = cast(self.getType()); - auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + auto bty = ty.toBuiltinTensor(); if (!bty.hasStaticShape()) return nullptr; @@ -2656,8 +2651,7 @@ LogicalResult AtenSortOp::fold(FoldAdaptor adaptor, if (!indicesTensorType.hasDtype()) return failure(); - auto indicesType = - indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype()); + auto indicesType = indicesTensorType.toBuiltinTensor(); if (!indicesType || !indicesType.hasStaticShape()) return failure(); @@ -3612,9 +3606,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { return nullptr; if (input && input.isSplat()) - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), - input.getSplatValue()); + return DenseElementsAttr::get(outType.toBuiltinTensor(), + input.getSplatValue()); int count = 1; for (auto dim : outType.getSizes()) @@ -3652,8 +3645,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { for (int i = begin; i < limit; i += stride) values.push_back(input.getValues()[i]); - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), values); + return DenseElementsAttr::get(outType.toBuiltinTensor(), values); } // If the input and output shapes are the same we can just fold: @@ -3923,7 +3915,7 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType 
shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); SmallVector data; if (matchPattern(getData(), m_TorchListOfConstantInts(data)) && @@ -3944,7 +3936,7 @@ OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); int64_t data; if (matchPattern(getT(), m_TorchConstantInt(&data))) { @@ -3964,7 +3956,7 @@ OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); double data; if (matchPattern(getT(), m_TorchConstantFloat(&data))) { @@ -4137,7 +4129,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) { : selfAttr.getValues()[indexInt]; auto dty = resultTy.getDtype(); - auto attrTy = resultTy.toBuiltinTensor().clone(dty); + auto attrTy = resultTy.toBuiltinTensor(); if (auto floatAttr = dyn_cast(splattr)) return DenseElementsAttr::get( attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble())); @@ -4330,7 +4322,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!valueDense.isSplat()) return nullptr; auto splattr = valueDense.getSplatValue(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, splattr); } @@ -4338,7 +4330,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!isa(dty)) return nullptr; int64_t intval = intAttr.getInt(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval)); } @@ -4346,7 +4338,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!isa(dty)) return nullptr; double dblval = fpAttr.getValueAsDouble(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval)); } diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 6735bb37e48b..12aea1589a4d 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -453,12 +453,7 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { } static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { - if (auto floatType = dyn_cast(dtype)) { - return dtype; - } else if (auto integerType = dyn_cast(dtype)) { - return IntegerType::get(context, integerType.getWidth(), - IntegerType::Signless); - } else if (isa(dtype)) { + if (isa(dtype)) { return dtype; } @@ -480,11 +475,11 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { TensorType ValueTensorType::toBuiltinTensor() const { if (!hasDtype()) return nullptr; - if (!hasSizes()) - return UnrankedTensorType::get(getDtype()); Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype()); if (!elementType) return nullptr; + if (!hasSizes()) + return UnrankedTensorType::get(elementType); return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType, getOptionalSparsity()); } diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp 
b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index deeef0658a52..c4f22715ab34 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -164,7 +164,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion( ConversionTarget &target, TypeConverter &typeConverter) { auto valueTensorTypeConversion = [](Torch::ValueTensorType type) -> std::optional { - return type.toBuiltinTensor(); + auto builtinType = type.toBuiltinTensor(); + if (!builtinType) + return std::nullopt; + + // convert any integer type to signless + if (type.getDtype().isInteger()) { + return builtinType.clone(IntegerType::get( + builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(), + IntegerType::Signless)); + } + + return builtinType; }; setupValueTensorToBuiltinTensorConversion(target, typeConverter, valueTensorTypeConversion); @@ -180,9 +191,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo( auto valueTensorTypeConversion = [](Torch::ValueTensorType type) -> std::optional { auto builtinType = type.toBuiltinTensor(); + if (!builtinType) + return std::nullopt; + + // convert signed integer type to signless, keep unsigned as unsigned if (type.getDtype().isUnsignedInteger()) { return builtinType.clone(type.getDtype()); + } else if (type.getDtype().isSignedInteger()) { + return builtinType.clone(IntegerType::get( + builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(), + IntegerType::Signless)); } + return builtinType; }; setupValueTensorToBuiltinTensorConversion(target, typeConverter, From d35b6b412aa7252eb377967f4feb2a753ec1a7fb Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Sat, 8 Jun 2024 09:58:11 +0530 Subject: [PATCH 04/12] [ONNX] Add OnnxToTorch Lowering for Sequence Ops (#3425) This commit adds the lowering for SequenceAt, SequenceEmpty, SequenceInsert, SequenceErase op Signed-Off By: Vivek Khandelwal --- .../Conversion/TorchOnnxToTorch/Patterns.h | 12 ++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 138 +++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 189 ++++++++++++++++++ 3 files changed, 339 insertions(+) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 0de85f4eebe5..f296b6dfee5c 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -110,6 +110,18 @@ struct OpBinder { return success(); } + ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) { + if (idx >= op->getNumOperands()) + return failure(); + valueIdx = op->getOperand(idx); + auto tt = dyn_cast(valueIdx.getType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + return success(); + } + ParseResult tensorListResultType(Torch::ListType &type0) { if (op->getNumResults() != 1) return failure(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 1eb5bcc1c67c..18399aa2d4d2 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -3120,7 +3120,145 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.replaceOpWithNewOp( binder.op, resultType, inputLTNegLambd, inputPlusBias, inputSubBiasOrZero); + return success(); + }); + patterns.onOp("SequenceAt", 11, + [](OpBinder binder, 
ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value inputSequence, position; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorOperandAtIndex(position, 1) || + binder.tensorResultType(resultType)) + return failure(); + + Value index = rewriter.create( + binder.getLoc(), rewriter.getType(), + position); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputSequence, index); + return success(); + }); + patterns.onOp( + "SequenceEmpty", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + int64_t dtypeIntOnnx; + if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.tensorListResultType(resultType)) + return failure(); + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value shapeList = createConstantIntList(binder, rewriter, {}); + Value cstNone = rewriter.create(binder.getLoc()); + + Value self = rewriter.create( + binder.op->getLoc(), resultType.getContainedType(), shapeList, + /*dtype=*/constDtype, + /*layout=*/cstNone, + /*device=*/cstNone, /*pinMemory=*/cstNone, + /*memoryFormat=*/cstNone); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, llvm::SmallVector{self}); + return success(); + }); + patterns.onOp( + "SequenceErase", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + Value inputSequence, position; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorListResultType(resultType)) + return failure(); + + Value length = rewriter.create( + binder.getLoc(), rewriter.getType(), inputSequence); + + Value cstNone = rewriter.create(binder.getLoc()); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + if (binder.op->getNumOperands() == 1) { + // If True, it means that the `position` arg is missing and + // the last tensor from the list has to be erased. + Value lengthMinusOne = rewriter.create( + binder.getLoc(), length, cstOne); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputSequence, /*start=*/cstNone, + /*end=*/lengthMinusOne, /*step=*/cstOne); + return success(); + } + + if (binder.tensorOperandAtIndex(position, 1)) + return failure(); + + Value positionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), position); + // Handling negative position value. 
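+          // A negative position counts from the end of the sequence, so it
+          // is normalized below without control flow:
+          //   position += Int(position < 0) * length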
+ Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value isPositionNegative = rewriter.create( + binder.getLoc(), positionInt, cstZero); + isPositionNegative = rewriter.create( + binder.getLoc(), isPositionNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isPositionNegative, length); + positionInt = rewriter.create( + binder.getLoc(), positionInt, finalOffset); + + Value listBeforePosition = rewriter.create( + binder.getLoc(), resultType, inputSequence, /*start=*/cstNone, + /*end=*/positionInt, /*step=*/cstOne); + Value positionPlusOne = rewriter.create( + binder.getLoc(), positionInt, cstOne); + Value listAfterPosition = rewriter.create( + binder.getLoc(), resultType, inputSequence, + /*start=*/positionPlusOne, + /*end=*/length, /*step=*/cstOne); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, listBeforePosition, listAfterPosition); + return success(); + }); + patterns.onOp( + "SequenceInsert", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + Value inputSequence, position, insertValue; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorOperandAtIndex(insertValue, 1) || + binder.tensorListResultType(resultType)) + return failure(); + + if (binder.op->getNumOperands() == 1) { + // If True, it means that the `position` arg is missing and + // the tensor has to be inserted at the end of the list. + Value length = rewriter.create( + binder.getLoc(), rewriter.getType(), + inputSequence); + rewriter.replaceOpWithNewOp( + binder.op, inputSequence, /*idx=*/length, + /*el=*/insertValue); + return success(); + } + + if (binder.tensorOperandAtIndex(position, 2)) + return failure(); + Value positionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), position); + rewriter.create(binder.getLoc(), inputSequence, + /*idx=*/positionInt, + /*el=*/insertValue); + rewriter.replaceOp(binder.op, inputSequence); return success(); }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index eb5a9f7cac4a..317a3aeb155f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2339,3 +2339,192 @@ func.func @test_shrink_hard(%arg0: !torch.vtensor<[5],f32>) -> !torch.vtensor<[5 %0 = torch.operator "onnx.Shrink"(%arg0) {torch.onnx.lambd = 1.500000e+00 : f32} : (!torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> return %0 : !torch.vtensor<[5],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_sequence_at +func.func @test_sequence_at(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: 
%[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_0]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %3 = torch.operator "onnx.SequenceErase"(%2, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + %4 = torch.operator "onnx.SequenceAt"(%3, %1) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32> + return %4 : !torch.vtensor<[2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_insert +func.func @test_sequence_insert(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-3> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_2:.*]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], 
%[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.insert.t %[[CONCAT_LIST]], %[[ITEM_0]], %arg0 : !torch.list>, !torch.int, !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[VTENSOR_2]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_1]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-3> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + %5 = torch.operator "onnx.SequenceInsert"(%4, %arg0, %1) : (!torch.list>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[],si64>) -> !torch.list> + %6 = torch.operator "onnx.SequenceAt"(%5, %2) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32> + return %6 : !torch.vtensor<[2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_at_beginning +func.func @test_sequence_erase_at_beginning(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = 
torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_at_end +func.func @test_sequence_erase_at_end(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : 
!torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_negative_idx +func.func @test_sequence_erase_negative_idx(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-2> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-2> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_empty +func.func @test_sequence_erase_empty() -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE_0:.*]] = 
torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE_0]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%1, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_empty +func.func @test_sequence_empty() -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list> + return %0 : !torch.list> +} From 5bc626465b0daaad68ad3d6fb1f6fdf4746dfef4 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Sun, 9 Jun 2024 12:07:20 +0530 Subject: [PATCH 05/12] [ONNX] Lower Onnx.Concat lowering version (#3437) Signed-Off By: Vivek Khandelwal --- lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index b6cc7cdd0ac9..31deadcafb7f 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -829,7 +829,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Concat", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; SmallVector tensors; int64_t dim; From 7e0e23c66820d1db548103acbdf1337f701dc5a3 Mon Sep 17 00:00:00 2001 From: Sambhav 
Jain
Date: Sun, 9 Jun 2024 00:32:49 -0700
Subject: [PATCH 06/12] Test custom op import with symbolic shapes (#3431)

Tests the basic constructs of registering a custom op and its abstract
implementations (with FakeTensors) in Python, going through TorchDynamo
export, followed by importing the shape expressions in the Torch dialect.

Also fixes the importer where previously the symbolic bind op insertion
was not gated in one place.
---
 python/torch_mlir/extras/fx_importer.py   |  3 +-
 test/python/fx_importer/custom_op_test.py | 86 +++++++++++++++++++++++
 2 files changed, 88 insertions(+), 1 deletion(-)
 create mode 100644 test/python/fx_importer/custom_op_test.py

diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 16c27c0fa318..2a73325c7d76 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -1445,7 +1445,8 @@ def import_nodes(
                     operands = [self._import_argument(loc, arg) for arg in node.args[0]]
                     func_dialect.ReturnOp(operands, loc=loc)
 
-            self._create_bind_symbolic_shape_ops(loc, node)
+            if import_symbolic_shape_expressions:
+                self._create_bind_symbolic_shape_ops(loc, node)
 
     def _promote_symbolic_scalar_int_float(self, loc, graph, param):
         temp_target = torch.ops.aten.Float.Scalar
diff --git a/test/python/fx_importer/custom_op_test.py b/test/python/fx_importer/custom_op_test.py
new file mode 100644
index 000000000000..dbbc5ba057af
--- /dev/null
+++ b/test/python/fx_importer/custom_op_test.py
@@ -0,0 +1,86 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
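+
+# (This test registers a custom op with a meta/abstract implementation and
+# checks that its dynamic shape constraints are imported as
+# torch.symbolic_int / torch.bind_symbolic_shape ops.)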
+ +# RUN: %PYTHON %s | FileCheck %s + +import torch +import torch.nn as nn +from torch.export import Dim +from torch.library import Library, impl, impl_abstract + +from torch_mlir import fx + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_tanh_sigmoid_cat_custom_op +# CHECK: func.func @main( +# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int +# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int +# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32> +def test_tanh_sigmoid_cat_custom_op(): + + m = Library("my_custom_library", "DEF") + m.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor") + + @impl(m, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd") + def custom_op(x, y, z): + a = torch.tanh(x) + b = torch.sigmoid(y) + return torch.cat((a, a, b, z), dim=1) + + @impl_abstract("my_custom_library::tanh_sigmoid_cat_op") + def custom_op_meta(x, y, z): + result = custom_op(x, y, z) + return torch.empty_like(result) + + class TanhSigmoidCatCustomOp(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + return torch.ops.my_custom_library.tanh_sigmoid_cat_op(x, y, z) + + # Sample inputs + x = torch.randn(5, 2, 3) + y = torch.randn(5, 6, 3) + z = torch.randn(5, 4, 3) + + # Dynamic dim constraints + dim_n = Dim("n", min=5, max=10) + dim_x1 = Dim("x1", max=100) + dim_y1 = Dim("y1", max=50) + dim_z1 = Dim("z1") + dynamic_shapes = { + "x": {0: dim_n, 1: dim_x1}, + "y": {0: dim_n, 1: dim_y1}, + "z": {0: dim_n, 1: dim_z1}, + } + + m = fx.export_and_import( + TanhSigmoidCatCustomOp(), + x, + y, + z, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) From d77bab37d1473dc48340e4807fd382d64c3cd5eb Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 10 Jun 2024 11:19:32 -0700 Subject: [PATCH 07/12] [torch-mlir][sparse] re-enable all sparse tests (#3444) this fixes the following issue: https://github.com/llvm/torch-mlir/issues/3418 --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 2 + test/python/fx_importer/sparse_test.py | 64 +++++++++++++++---- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git 
a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index b9b0fb0ae5d7..dc8b5d431002 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -2579,6 +2579,8 @@ class ConvertSparseOperatorOp : public OpConversionPattern { SmallVector ConvertSparseOperatorOp::legalizedNames = { "torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc", "torch.aten._to_bsr", "torch.aten._to_bsc", + "torch.aten.to_dense", "torch.aten.to_sparse", "torch.aten.to_csr", + "torch.aten.to_csc", "torch.aten.to_bsr", "torch.aten.to_bsc", }; } // namespace diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 41872b77e928..7c7198ef6f61 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -125,7 +125,7 @@ def sparse_export( # Zero preserving elt-wise unary op. if opname in {"abs", "neg", "relu", "sin"}: node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) - elif opname == "_to_sparse": + elif opname == "_to_sparse" or opname == "to_sparse": dim = len(node.meta.get("val").shape) node.meta["sparsity"] = SparsityMeta( torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 @@ -339,6 +339,14 @@ def forward(self, x, v): @run # +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, +# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> { +# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32> +# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32> +# CHECK: } +## # CHECK: torch.sparse # CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], # CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], @@ -360,7 +368,7 @@ def forward(self, x, y): dense_input = torch.ones(8, 8) sparse_input = dense_input.to_sparse_coo() m = export_and_import(net, sparse_input, dense_input) - # print(m) + print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. 
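     # (sparse_jit compiles the exported module through torch-mlir and runs
     # it, so the "torch.mlir" values printed below should match the eager
     # PyTorch result.)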
res1 = net(sparse_input, dense_input) @@ -500,12 +508,29 @@ def forward(self, x): @run # +# CHECK-LABEL: test_sparse_activation +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> { +# CHECK: %[[N1:.*]] = torch.constant.none +# CHECK: %[[N2:.*]] = torch.constant.none +# CHECK: %[[N3:.*]] = torch.constant.none +# CHECK: %[[R:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> +# CHECK: return %[[R]] : !torch.vtensor<[2,2,2],f32,#[[$COO]]> +# CHECK: } +# # CHECK: torch.sparse # CHECK: tensor(indices=tensor({{\[}}[0, 0, 0, 0, 1, 1, 1, 1], # CHECK: [0, 0, 1, 1, 0, 0, 1, 1], # CHECK: [0, 1, 0, 1, 0, 1, 0, 1]{{\]}}), # CHECK: values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]), # CHECK: size=(2, 2, 2), nnz=8, layout=torch.sparse_coo) +# CHECK: torch.mlir +# CHECK: [0 8] +# CHECK: [0 0 0 0 1 1 1 1] +# CHECK: [0 0 1 1 0 0 1 1] +# CHECK: [0 1 0 1 0 1 0 1] +# CHECK: [1. 1. 1. 1. 1. 1. 1. 1.] # def test_sparse_activation(): class SparseActivationCOO(torch.nn.Module): @@ -515,19 +540,19 @@ def forward(self, x): net = SparseActivationCOO() x = torch.ones(2, 2, 2) m = export_and_import(net, x) - # print(m) + print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. res1 = net(x) - # res2 = sparse_jit(net, x) + res2 = sparse_jit(net, x) print("torch.sparse") print(res1) - # print("torch.mlir") - # print(res2[0]) - # print(res2[1]) - # print(res2[2]) - # print(res2[3]) - # print(res2[4]) + print("torch.mlir") + print(res2[0]) + print(res2[1]) + print(res2[2]) + print(res2[3]) + print(res2[4]) @run @@ -542,6 +567,8 @@ def forward(self, x): # # CHECK: torch.sparse # CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.]) +# CHECK: torch.mlir +# CHECK: [ 0. 11. 9. 11. 13. 11. 10. 12.] # def test_sparse_network(): def spike(input): @@ -607,15 +634,24 @@ def forward(self, X): # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. res1 = net(x) - # res2 = sparse_jit(net, x) + res2 = sparse_jit(net, x) print("torch.sparse") print(res1) - # print("torch.mlir") - # print(res2) + print("torch.mlir") + print(res2) @run # +# CHECK-LABEL: test_sparse_feature_scaling +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { +# ... more IR ... +# CHECK: %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}" +# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]] +# CHECK return %[[R]] : !torch.vtensor<[4,4],f32> +# CHECK: } +# # CHECK: torch.sparse # CHECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889], # CHECK: [0.1321, 0.2724, 0.2105, 0.3851], @@ -638,7 +674,7 @@ def forward(self, F): torch.manual_seed(0) f = torch.rand(4, 4) m = export_and_import(net, f) - # print(m) + print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. 
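    # (The dense-to-sparse conversion inside forward is what the
    # torch.aten.to_sparse/_to_sparse CHECK lines above match.)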
res1 = net(f)

From e07a0bfc5464c3f2cd3f3a7e2a581f55ea99e176 Mon Sep 17 00:00:00 2001
From: Matthias Gehre 
Date: Mon, 10 Jun 2024 20:59:29 +0200
Subject: [PATCH 08/12] onnx.resize: Add support for coordTfMode "half_pixel"
 (#3441)

half_pixel is also the default mode used by ONNX, see
https://onnx.ai/onnx/operators/onnx__Resize.html
---
 lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp |  5 ++-
 lib/Conversion/TorchToLinalg/Uncategorized.cpp        | 14 ++++++-
 test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir | 13 ++++++
 test/Conversion/TorchToLinalg/resize.mlir             | 41 +++++++++++++++++++
 4 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 18399aa2d4d2..67370567ad6b 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2823,10 +2823,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
               binder.op, "unimplemented: coordinate transformation mode: "
                          "tf_crop_and_resize");
 
-        if (mode == "nearest" && coordTfMode != "asymmetric") {
+        if (mode == "nearest" && coordTfMode != "asymmetric" &&
+            coordTfMode != "half_pixel") {
           return rewriter.notifyMatchFailure(
               binder.op, "unimplemented: support not present for coord tf mode "
-                         "except asymmetric");
+                         "except asymmetric and half_pixel");
         }
         unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index b6fc225c42fe..a1c3003e32a4 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -2631,7 +2631,17 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
   Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), outIndex);
   Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);
-  Value proj = b.create<arith::MulFOp>(loc, outFP, scale);
+  Value proj;
+  if (coordStr.empty() || coordStr == "_asymmetric") {
+    proj = b.create<arith::MulFOp>(loc, outFP, scale);
+  } else if (coordStr == "_half_pixel") {
+    Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
+    Value add = b.create<arith::AddFOp>(loc, outFP, cstHalf);
+    Value div = b.create<arith::DivFOp>(loc, add, scale);
+    proj = b.create<arith::SubFOp>(loc, div, cstHalf);
+  } else {
+    llvm_unreachable("Unsupported coordination transformation mode");
+  }
 
   Value nearestFP;
   // get nearest pixel using floor
@@ -2655,6 +2665,8 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
     nearestFP = b.create<arith::SelectOp>(loc, cmp, ceil, floor);
   } else if (nearestMode == "ceil") {
     nearestFP = b.create<math::CeilOp>(loc, proj);
+  } else {
+    llvm_unreachable("Unsupported nearest mode");
   }
 
   Value nearestInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), nearestFP);
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 317a3aeb155f..ae47b49b06f3 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -2183,6 +2183,19 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
 
 // -----
 
+// CHECK-LABEL: func.func @test_resize_sizes_nearest
+func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %none = torch.constant.none
+  // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
"nearest_half_pixel,round_prefer_floor" + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.coordinate_transformation_mode = "half_pixel", + torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_resize_sizes_linear func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 8d714fda0c5f..6847d25736f1 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -155,3 +155,44 @@ func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: %7 = torch.aten.__interpolate.size_list_scale_list %arg0, %6, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32> return %7 : !torch.vtensor<[?,?,?,?,?],f32> } + +// CHECK-LABEL: func.func @test_resize_nearest_half_pixel +func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[cst:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[add:.*]] = arith.addf %[[x24]], %[[cst]] : f32 + // CHECK: %[[x25:.*]] = arith.divf %[[add]], %[[x21]] : f32 + // CHECK: %[[sub:.*]] = arith.subf %[[x25]], %[[cst]] : f32 + // CHECK: %[[cst3:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[floor:.*]] = math.floor %[[sub]] : f32 + // CHECK: %[[ceil:.*]] = math.ceil %[[sub]] : f32 + // CHECK: %[[sub2:.*]] = arith.subf %[[sub]], %[[floor]] : f32 + // CHECK: %[[cmpf:.*]] = arith.cmpf ule, %[[sub2]], %[[cst3]] : f32 + // CHECK: %[[select:.*]] = arith.select %[[cmpf]], %[[floor]], %[[ceil]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[select]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest_half_pixel,round_prefer_floor" + %int2 = torch.constant.int 2 + %0 = 
+  %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>
+  %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32>
+  return %5 : !torch.vtensor<[?,?,?],f32>
+}
+
+// -----

From 7cd3368b206bbcfb9cf272bfa7e532e60f574fc8 Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Tue, 11 Jun 2024 22:35:50 -0500
Subject: [PATCH 09/12] [ONNX] Fix resize ceil numerics and add
 half_pixel_symmetric support (#3443)

This patch fixes several failing tests in our [external test
suite](https://github.com/nod-ai/SHARK-TestSuite/tree/main/iree_tests/onnx/node/generated),
and addresses some of the issues discussed in #3420

---
 lib/Conversion/TorchToLinalg/Uncategorized.cpp | 22 ++++-
 test/Conversion/TorchToLinalg/resize.mlir      | 84 ++++++++++++++++++-
 2 files changed, 104 insertions(+), 2 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index a1c3003e32a4..1330174699a5 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -2657,14 +2657,21 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
     nearestFP = b.create<arith::SelectOp>(loc, cmp, floor, ceil);
   } else if (nearestMode == "round_prefer_ceil") {
     Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
+    Value cstOne = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1));
     Value floor = b.create<math::FloorOp>(loc, proj);
     Value ceil = b.create<math::CeilOp>(loc, proj);
     Value decimal = b.create<arith::SubFOp>(loc, proj, floor);
     Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
                                         decimal, cstHalf);
     nearestFP = b.create<arith::SelectOp>(loc, cmp, ceil, floor);
+    Value inputSizeMOne = b.create<arith::SubFOp>(loc, inputSizeFP, cstOne);
+    // don't extract out of bounds
+    nearestFP = b.create<arith::MinimumFOp>(loc, nearestFP, inputSizeMOne);
   } else if (nearestMode == "ceil") {
+    Value cstOne = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1));
+    Value inputSizeMOne = b.create<arith::SubFOp>(loc, inputSizeFP, cstOne);
     nearestFP = b.create<math::CeilOp>(loc, proj);
+    nearestFP = b.create<arith::MinimumFOp>(loc, nearestFP, inputSizeMOne);
   } else {
     llvm_unreachable("Unsupported nearest mode");
   }
@@ -2738,7 +2745,8 @@ static Value BilinearInterpolate(OpBuilder &b,
   if (coordStr == "_asymmetric") {
     preClip = b.create<arith::MulFOp>(loc, outFP, scale);
   }
-  if (coordStr == "_pytorch_half_pixel" || coordStr == "") {
+  if (coordStr == "_pytorch_half_pixel" || coordStr == "" ||
+      coordStr == "_half_pixel_symmetric") {
     // half-pixel modes
     // y_resized + 0.5
     Value outPlusHalf = b.create<arith::AddFOp>(loc, outFP, cstHalf);
     // _ / scale
     Value outDivScale = b.create<arith::DivFOp>(loc, outPlusHalf, scale);
     // _ - 0.5
     preClip = b.create<arith::SubFOp>(loc, outDivScale, cstHalf);
   }
+  // for half_pixel_symmetric, need to compute offset from raw scales
+  if (coordStr == "_half_pixel_symmetric" && !scaleValues.empty()) {
+    Value outputSizeFromScale = b.create<arith::MulFOp>(loc, inputFP, scale);
+    Value adjustment =
+        b.create<arith::DivFOp>(loc, outputSizeFP, outputSizeFromScale);
+    Value cstTwo = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
+    Value center = b.create<arith::DivFOp>(loc, inputFP, cstTwo);
+    Value oneMAdjustment =
+        b.create<arith::SubFOp>(loc, cstOneFloat, adjustment);
+    Value offset = b.create<arith::MulFOp>(loc, center, oneMAdjustment);
+    preClip = b.create<arith::AddFOp>(loc, offset, preClip);
+  }
   // for pytorch half pixel , special case for length_resized == 1:
   if (coordStr == "_pytorch_half_pixel") {
     Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir
index 6847d25736f1..64198d03f2a1 100644
--- a/test/Conversion/TorchToLinalg/resize.mlir
+++ b/test/Conversion/TorchToLinalg/resize.mlir
@@ -156,7 +156,89 @@ func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1:
   return %7 : !torch.vtensor<[?,?,?,?,?],f32>
 }
 
-// CHECK-LABEL: func.func @test_resize_nearest_half_pixel
+// -----
+
+// CHECK-LABEL: func.func @test_resize_nearest_ceil
+func.func @test_resize_nearest_ceil(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> {
+  // CHECK: %[[GENERIC:.*]] = linalg.generic
+  // CHECK: %[[x11:.*]] = linalg.index 0 : index
+  // CHECK: %[[x12:.*]] = linalg.index 1 : index
+  // CHECK: %[[x13:.*]] = linalg.index 2 : index
+  // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
+  // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
+  // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
+  // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64
+  // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32
+  // CHECK: %[[cst:.*]] = arith.constant 5.000000e-01 : f32
+  // CHECK: %[[add:.*]] = arith.addf %[[x24]], %[[cst]] : f32
+  // CHECK: %[[x25:.*]] = arith.divf %[[add]], %[[x21]] : f32
+  // CHECK: %[[sub:.*]] = arith.subf %[[x25]], %[[cst]] : f32
+  // CHECK: %[[cst3:.*]] = arith.constant 1.000000e+00 : f32
+  // CHECK: %[[nM1:.*]] = arith.subf %[[inputsizefp:.*]], %[[cst3]]
+  // CHECK: %[[ceil:.*]] = math.ceil %[[sub]] : f32
+  // CHECK: %[[minindex:.*]] = arith.minimumf %[[ceil]], %[[nM1]]
+  // CHECK: %[[x31:.*]] = arith.fptosi %[[minindex]] : f32 to i64
+  // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index
+  // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor<?x?x?xf32>
+  // CHECK: linalg.yield %[[extracted]] : f32
+  %none = torch.constant.none
+  %none_0 = torch.constant.none
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %str = torch.constant.str "nearest_half_pixel,ceil"
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>
+  %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32>
+  return %5 : !torch.vtensor<[?,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_resize_scales_linear_half_pixel_symmetric
+func.func @test_resize_scales_linear_half_pixel_symmetric(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],f64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[generic:.*]] = linalg.generic
+  // CHECK: %[[cst7:.*]] = arith.constant 2.0
+  // CHECK: %[[halfsize:.*]] = arith.divf %[[sizefp:.*]], %[[cst7]]
+  // CHECK: %[[modifier:.*]] = arith.subf %[[cstOne:.*]], %[[adjustment:.*]]
+  // CHECK: %[[offset:.*]] = arith.mulf %[[halfsize]], %[[modifier]]
+  // CHECK: %[[preClip:.*]] = arith.addf %[[offset]], %[[halfpixelbase:.*]]
+  // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32>
+  // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
+  // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
+  // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
+  // CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]]
+  // CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]]
+  // CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]]
+  // CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]]
+  // CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]]
+  // CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]]
+  // CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]]
+  // CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]]
+  // CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]]
+  %none = torch.constant.none
+  %none_0 = torch.constant.none
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %str = torch.constant.str "bilinear_half_pixel_symmetric"
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],f64>, !torch.int, !torch.int -> !torch.vtensor<[1],f64>
+  %1 = torch.aten.item %0 : !torch.vtensor<[1],f64> -> !torch.float
+  %int3 = torch.constant.int 3
+  %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],f64>, !torch.int, !torch.int -> !torch.vtensor<[1],f64>
+  %3 = torch.aten.item %2 : !torch.vtensor<[1],f64> -> !torch.float
+  %4 = torch.prim.ListConstruct %1, %3 : (!torch.float, !torch.float) -> !torch.list<float>
+  %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %none_0, %4, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  return %5 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_resize_nearest_half_pixel_round_prefer_floor
 func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> {
   // CHECK: %[[GENERIC:.*]] = linalg.generic
   // CHECK: %[[x11:.*]] = linalg.index 0 : index

From de28c8540b3d08fa685dd397c170e609323a79ce Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Wed, 12 Jun 2024 00:07:22 -0500
Subject: [PATCH 10/12] [ONNX] add int16 quantization support (#3446)

There is currently no int16 quantization support in torch. This patch
adds a new mlir type to correspond to the missing "torch.qint16" type,
and enables lowering of quantization-related onnx ops using int16
types.

In follow-up patches, custom quantization logic for ops like
aten.matmul/aten.mm/aten.convolution may need to be revisited to allow
support for qint16. The passes in FuseQuantizedOps.cpp may also need
slight modifications.
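
For a rough sketch of what this enables (modeled on the lit test added
below; SSA names are illustrative rather than exact importer output),
an onnx.DequantizeLinear on si16 operands now legalizes as:

    // Assumed post-conversion IR for a [6]xsi16 DequantizeLinear.
    %scale = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
    %zp = torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int
    // Reinterpret the raw si16 data as a tensor of the new quantized type.
    %q = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %zp
        : !torch.vtensor<[6],si16>, !torch.float, !torch.int
        -> !torch.vtensor<[6],!torch.qint16>
    // Dequantize back to f32.
    %dq = torch.aten.dequantize.self %q
        : !torch.vtensor<[6],!torch.qint16> -> !torch.vtensor<[6],f32>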
---
 include/torch-mlir-c/TorchTypes.h             | 13 +++++++++++++
 .../Conversion/TorchOnnxToTorch/Utils.h       |  2 +-
 .../torch-mlir/Dialect/Torch/IR/TorchTypes.td | 10 ++++++++++
 .../Dialect/Torch/Utils/TorchUpstream.h       |  3 ++-
 lib/CAPI/TorchTypes.cpp                       | 16 ++++++++++++++++
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    | 13 ++-----------
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 17 ++++-------------
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 18 ++++--------------
 lib/Conversion/TorchOnnxToTorch/Utils.cpp     |  5 ++++-
 lib/Conversion/TorchToLinalg/Utils.cpp        |  2 ++
 lib/Dialect/Torch/IR/TorchTypes.cpp           |  6 +++++-
 .../Torch/Transforms/MatchQuantizedOps.cpp    |  4 +++-
 lib/Dialect/Torch/Utils/TorchUpstream.cpp     |  2 +-
 lib/Dialect/Torch/Utils/Utils.cpp             |  4 ++++
 .../TorchOnnxToTorch/simple_ops_a_to_f.mlir   | 13 +++++++++++++
 test/Dialect/Torch/ops.mlir                   |  1 +
 16 files changed, 85 insertions(+), 44 deletions(-)

diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h
index b214e147d5d9..dd7cfb5c428f 100644
--- a/include/torch-mlir-c/TorchTypes.h
+++ b/include/torch-mlir-c/TorchTypes.h
@@ -220,6 +220,19 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
 /// Gets the !torch.quint8 typeid.
 MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void);
 
+//===----------------------------------------------------------------------===//
+// torch.qint16 type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.qint16 type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt16(MlirType t);
+
+/// Gets the !torch.qint16 type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt16TypeGet(MlirContext context);
+
+/// Gets the !torch.qint16 typeid.
+MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt16TypeGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
index d8d2534f9a0c..4bf6c845c68a 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -36,7 +36,7 @@ Value createConstantIntList(OpBinder binder,
                             ConversionPatternRewriter &rewriter,
                             SmallVector<int64_t> cstInput);
 
-Type getQTorchTypeFromTorchIntType(Type ty);
+Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);
 
 template <typename T>
 Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
index 279e694540f9..367b08610cd8 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -315,6 +315,16 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
   }];
 }
 
+def Torch_QInt16Type : Torch_Type<"QInt16", "qint16"> {
+  let summary = "Type modeling `ScalarType::QInt16`, which doesn't yet exist";
+  let description = [{
+    Pytorch does not have 16-bit integer quantization support.
+
+    This torch type is added to provide a target for 16-bit quantization
+    schemes coming from imported onnx models.
+  }];
+}
+
 def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
   let summary = "Type modeling `ScalarType::QUInt8`";
   let description = [{
diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
index 3d2c8bb588d7..e2b57538d7e6 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -112,7 +112,8 @@ enum class TypeKind {
   _(c10::Float8_e5m2, Float8_e5m2)         /* 23 */                           \
   _(c10::Float8_e4m3fn, Float8_e4m3fn)     /* 24 */                           \
   _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */                           \
-  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */
+  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */                           \
+  _(c10::qint16, QInt16)                   /* 27 */
 
 enum class ScalarType : int8_t {
 #define DEFINE_ENUM(_1, n) n,
diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp
index 399915459e40..6402e44a3701 100644
--- a/lib/CAPI/TorchTypes.cpp
+++ b/lib/CAPI/TorchTypes.cpp
@@ -269,6 +269,22 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
   return wrap(Torch::QUInt8Type::getTypeID());
 }
 
+//===----------------------------------------------------------------------===//
+// torch.qint16 type.
+//===----------------------------------------------------------------------===//
+
+bool torchMlirTypeIsATorchQInt16(MlirType t) {
+  return isa<Torch::QInt16Type>(unwrap(t));
+}
+
+MlirType torchMlirTorchQInt16TypeGet(MlirContext context) {
+  return wrap(Torch::QInt16Type::get(unwrap(context)));
+}
+
+MlirTypeID torchMlirTorchQInt16TypeGetTypeID() {
+  return wrap(Torch::QInt16Type::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 31deadcafb7f..d0ff6e973a7e 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1715,21 +1715,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                                              "requires known result dtype");
         if (scaleTy.getSizes().size() == 0 ||
             (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
-          Type qTy = operandTy.getDtype();
-
-          if (qTy.isUnsignedInteger(8)) {
-            qTy = rewriter.getType<Torch::QUInt8Type>();
-          } else if (qTy.isSignedInteger(8)) {
-            qTy = rewriter.getType<Torch::QInt8Type>();
-          } else if (qTy.isSignedInteger(32)) {
-            qTy = rewriter.getType<Torch::QInt32Type>();
-          } else {
+          auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
+          if (!qTensorTy) {
             return rewriter.notifyMatchFailure(binder.op,
                                                "unsupported result dtype");
           }
 
-          auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
-              resultType.getOptionalSizes(), qTy);
           scale = rewriter.create<Torch::AtenItemOp>(
               binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
           zeropoint = rewriter.create<Torch::AtenItemOp>(
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index c7e41a7a097c..26f4ddb677ec 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -408,20 +408,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getF64FloatAttr(1.0));
 
-        auto q = [&](Type qty) -> Type {
-          if (qty.isSignedInteger(8))
-            return rewriter.getType<Torch::QInt8Type>();
-          if (qty.isUnsignedInteger(8))
-            return rewriter.getType<Torch::QUInt8Type>();
-          if (qty.isSignedInteger(32))
-            return rewriter.getType<Torch::QInt32Type>();
-          return {};
-        };
+        auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy);
+        auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy);
 
-        Type lhsQTy = rewriter.getType<Torch::ValueTensorType>(
-            lhsTy.getOptionalSizes(), q(lhsTy.getDtype()));
-        Type rhsQTy = rewriter.getType<Torch::ValueTensorType>(
-            rhsTy.getOptionalSizes(), q(rhsTy.getDtype()));
+        if (!lhsQTy || !rhsQTy)
+          return rewriter.notifyMatchFailure(binder.op, "failed to get qtype");
 
         lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
             binder.getLoc(), lhsQTy, lhs, scale, lhsZp);
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 67370567ad6b..381063096776 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -177,22 +177,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                                              "requires known result dtype");
 
         if (scaleTy.getSizes().size() == 0) {
-          Type qTy = resultType.getDtype();
-
-          if (qTy.isUnsignedInteger(8)) {
-            qTy = rewriter.getType<Torch::QUInt8Type>();
-          } else if (qTy.isSignedInteger(8)) {
-            qTy = rewriter.getType<Torch::QInt8Type>();
-          } else if (qTy.isSignedInteger(32)) {
-            qTy = rewriter.getType<Torch::QInt32Type>();
-          } else {
+          auto qTensorTy = getQTorchTypeFromTorchIntType(resultType);
+          if (!qTensorTy) {
             return rewriter.notifyMatchFailure(binder.op,
                                                "unsupported result dtype");
           }
 
-          auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
-              resultType.getOptionalSizes(), qTy);
-          auto torchqTy = Torch::getScalarTypeForType(qTy);
+          auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype());
 
           Value tyConst = rewriter.create<Torch::ConstantIntOp>(
               binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -311,8 +302,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
       c = rewriter.create(binder.getLoc(), cTy, c);
 
-      cTy = dyn_cast<Torch::ValueTensorType>(
-          getQTorchTypeFromTorchIntType(resultType));
+      cTy = getQTorchTypeFromTorchIntType(resultType);
       Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
           binder.getLoc(), rewriter.getType<Torch::IntType>(),
           rewriter.getIntegerAttr(
diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp
index e7baf2e243fc..bec6ade4270c 100644
--- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp
@@ -28,7 +28,8 @@ Value mlir::torch::onnx_c::createConstantIntList(
       cstValue);
 }
 
-Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
+Torch::ValueTensorType
+mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
   Torch::ValueTensorType tty = dyn_cast<Torch::ValueTensorType>(ty);
   if (!tty)
     return nullptr;
@@ -40,6 +41,8 @@ Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
     dty = Torch::QUInt8Type::get(ctx);
   if (dty.isSignedInteger(8))
     dty = Torch::QInt8Type::get(ctx);
+  if (dty.isSignedInteger(16))
+    dty = Torch::QInt16Type::get(ctx);
   if (dty.isSignedInteger(32))
     dty = Torch::QInt32Type::get(ctx);
 
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 7355327461d4..c2658f35cce3 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -565,6 +565,8 @@ bool torch_to_linalg::isUnsignedTorchType(Type type) {
     return false;
   if (isa<Torch::QUInt8Type>(type))
     return true;
+  if (isa<Torch::QInt16Type>(type))
+    return false;
  if (isa<Torch::QInt32Type>(type))
    return false;
  if (auto intTy = dyn_cast<IntegerType>(type))
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index 12aea1589a4d..c46865ee5fed 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -185,7 +185,8 @@ static bool isValidTorchDtype(Type dtype) {
     dtype = cast<mlir::ComplexType>(dtype).getElementType();
   }
   // Torch quantized types.
-  if (isa<QUInt8Type, QInt8Type, QInt32Type>(dtype))
+  if (isa<QUInt8Type, QInt8Type, QInt16Type, QInt32Type>(dtype))
     return true;
   // Builtin floating point types.
   if (isa<BFloat16Type, Float16Type, Float32Type, Float64Type>(dtype))
     return true;
@@ -463,6 +464,9 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
   if (isa<QInt8Type, QUInt8Type>(dtype))
     return IntegerType::get(context, 8, IntegerType::Signless);
 
+  if (isa<QInt16Type>(dtype))
+    return IntegerType::get(context, 16, IntegerType::Signless);
+
   if (isa<QInt32Type>(dtype))
     return IntegerType::get(context, 32, IntegerType::Signless);
 
diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp
index c237ede12479..b571003940cb 100644
--- a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp
@@ -21,10 +21,12 @@ using namespace mlir::torch::Torch;
 namespace {
 
 Type getQuantizedType(MLIRContext *context, Type t) {
-  if (t.isSignlessInteger(8))
+  if (t.isSignlessInteger(8) || t.isUnsignedInteger(8))
     return Torch::QUInt8Type::get(context);
   if (t.isInteger(8) || t.isSignedInteger(8))
     return Torch::QInt8Type::get(context);
+  if (t.isInteger(16))
+    return Torch::QInt16Type::get(context);
   if (t.isInteger(32))
     return Torch::QInt32Type::get(context);
   return {};
diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
index 2dce14ef964c..c4c42f7fe0e3 100644
--- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp
+++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
@@ -21,7 +21,7 @@ static inline bool isQIntType(ScalarType t) {
   // Don't forget to extend this when adding new QInt types
   return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
          t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
-         t == ScalarType::QUInt2x4;
+         t == ScalarType::QUInt2x4 || t == ScalarType::QInt16;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 388c38b25cb3..81a2de87b054 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -69,6 +69,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
     return torch_upstream::ScalarType::QUInt8;
   if (isa<QInt8Type>(type))
     return torch_upstream::ScalarType::QInt8;
+  if (isa<QInt16Type>(type))
+    return torch_upstream::ScalarType::QInt16;
   if (isa<QInt32Type>(type))
     return torch_upstream::ScalarType::QInt32;
   if (isa<mlir::ComplexType>(type)) {
@@ -128,6 +130,8 @@ Torch::getTypeForScalarType(MLIRContext *context,
     return QUInt8Type::get(context);
   case torch_upstream::ScalarType::QInt8:
     return QInt8Type::get(context);
+  case torch_upstream::ScalarType::QInt16:
+    return QInt16Type::get(context);
   case torch_upstream::ScalarType::QInt32:
     return QInt32Type::get(context);
   case torch_upstream::ScalarType::ComplexHalf:
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index 3f437fc4c5c1..5b33fd17471b 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -748,6 +748,19 @@ func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !tor
 
 // -----
 
+// CHECK-LABEL: @test_dequantizelinear_si16
+func.func @test_dequantizelinear_si16(%arg0: !torch.vtensor<[6],si16>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si16>, !torch.vtensor<[],f32>, !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32>
+  // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int
+  // CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]]
+  // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]]
+  // CHECK: return %[[DEQ]]
+  return %0 : !torch.vtensor<[6],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_dequantizelinear_ui8
 func.func @test_dequantizelinear_ui8(%arg0: !torch.vtensor<[6],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
   %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32>
diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir
index ecf5e626fb1d..1fdbf6e1d7d3 100644
--- a/test/Dialect/Torch/ops.mlir
+++ b/test/Dialect/Torch/ops.mlir
@@ -171,6 +171,7 @@ func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %
 
 func.func private @tensor_legal_dtype$torch.qint8() -> !torch.tensor<*,!torch.qint8>
 func.func private @tensor_legal_dtype$torch.quint8() -> !torch.tensor<*,!torch.quint8>
+func.func private @tensor_legal_dtype$torch.qint16() -> !torch.tensor<*,!torch.qint16>
 
 func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,56,96],f16>, %arg1: !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor> {
   %arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor>
 

From c0eb6d89c02c7e23cf213f97556dcc567b20cec9 Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Wed, 12 Jun 2024 10:55:14 -0500
Subject: [PATCH 11/12] [ONNX] add some args to the onnx importer to assist
 shape_inference (#3445)

Adds the following arguments:
- "--clear-domain": enabling this flag (default False) will delete the
  domain attribute from each node in the onnx model before importing.
  Shape inference does not seem to work for onnx ops in custom domains.
  In the rare case when these ops have a corresponding counterpart in
  base onnx, enabling this flag might allow shape inference to work
  properly.
- "--opset-version": allows setting the opset version manually. This
  will cause the importer to attempt to update the opset_version of the
  onnx model before importing. Newer opset versions sometimes have more
  robust shape inference patterns.
---
 python/torch_mlir/extras/onnx_importer.py    |  5 ++++
 python/torch_mlir/tools/import_onnx/__main__.py | 25 +++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py
index e0d3529d942e..4c1e0b9e9aed 100644
--- a/python/torch_mlir/extras/onnx_importer.py
+++ b/python/torch_mlir/extras/onnx_importer.py
@@ -34,6 +34,7 @@
 ) from e
 
 from typing import Optional, List, Dict, Tuple
+import warnings
 
 from dataclasses import dataclass
 
@@ -579,6 +580,10 @@ def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType:
 
     def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
         if tp == "":
+            warnings.warn(
+                "Found a node without a valid type proto. Consider updating the opset_version of"
+                " the model and/or running the importer with the flag '--clear-domain'."
+            )
             return self.get_none_type()
 
         tt = tp.tensor_type
diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py
index 92ae3c7eb356..bca87cee7f59 100644
--- a/python/torch_mlir/tools/import_onnx/__main__.py
+++ b/python/torch_mlir/tools/import_onnx/__main__.py
@@ -20,6 +20,7 @@
 import sys
 
 import onnx
+import onnx.version
 
 from ...extras import onnx_importer
 
@@ -81,6 +82,16 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
         raw_model = onnx.load(args.input_file, load_external_data=False)
         onnx.load_external_data_for_model(raw_model, args.data_dir)
 
+    if args.opset_version:
+        raw_model = onnx.version_converter.convert_version(
+            raw_model, args.opset_version
+        )
+
+    if args.clear_domain:
+        graph = raw_model.graph
+        for n in graph.node:
+            n.ClearField("domain")
+
     # Run the checker to test whether the file is above the threshold for
     # in-memory shape inference. If not, go ahead and do the shape inference.
     try:
@@ -149,6 +160,14 @@ def parse_arguments(argv=None) -> argparse.Namespace:
         action=argparse.BooleanOptionalAction,
         help="Toggle data propogation for onnx shape inference",
     )
+    parser.add_argument(
+        "--clear-domain",
+        dest="clear_domain",
+        default=False,
+        action=argparse.BooleanOptionalAction,
+        help="If enabled, this will clear the domain attribute from each node"
+        " in the onnx graph before performing shape inference.",
+    )
     parser.add_argument(
         "--keep-temps", action="store_true", help="Keep intermediate files"
     )
@@ -170,6 +189,12 @@ def parse_arguments(argv=None) -> argparse.Namespace:
         " Defaults to the directory of the input file.",
         type=Path,
     )
+    parser.add_argument(
+        "--opset-version",
+        help="Allows specification of a newer opset_version to update the model"
+        " to before importing to MLIR. This can sometimes assist with shape inference.",
+        type=int,
+    )
     args = parser.parse_args(argv)
     return args
 

From 41d04a89959d9197e167302fcae375f947848a88 Mon Sep 17 00:00:00 2001
From: Suraj Sudhir 
Date: Wed, 12 Jun 2024 09:23:42 -0700
Subject: [PATCH 12/12] [onnx] Resize supports default-valued attributes
 (#3450)

Handles onnx exporters emitting default-valued attributes.
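
For example, exporters frequently spell out the ONNX defaults
explicitly; a Resize such as the following sketch (attribute spellings
follow the existing lit tests; the si64/f32 attribute types are
assumed) now converts instead of hitting a match failure:

    %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
        torch.onnx.antialias = 0 : si64,
        torch.onnx.exclude_outside = 0 : si64,
        torch.onnx.extrapolation_value = 0.000000e+00 : f32,
        torch.onnx.coordinate_transformation_mode = "half_pixel",
        torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>,
        !torch.none, !torch.none, !torch.vtensor<[4],si64>)
        -> !torch.vtensor<[?,?,?,?],f32>

Non-default values for these attributes still produce a match failure.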
Signed-off-by: Suraj Sudhir 
---
 lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 36 +++++++++++--------
 1 file changed, 21 insertions(+), 15 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 381063096776..6b003b1259c0 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2771,28 +2771,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         Torch::ValueTensorType resultType;
         llvm::SmallVector<Value> operands;
         std::string mode, nearest_mode, coordTfMode;
+        int64_t antialias, exclude_outside;
+        float extrapolation_value;
         Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
 
-        if (auto attr = binder.op->getAttr("torch.onnx.antialias")) {
-          return rewriter.notifyMatchFailure(
-              binder.op,
-              "unimplemented: support not present for antialias attribute");
-        }
         if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
           return rewriter.notifyMatchFailure(
               binder.op,
               "unimplemented: support not present for axes attribute");
         }
-        if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "unimplemented: support not present for "
-                         "exclude_outside attribute");
-        }
-        if (auto attr = binder.op->getAttr("torch.onnx.extrapolation_value")) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "unimplemented: support not present for "
-                         "extrapolation_value attribute");
-        }
         if (auto attr =
                 binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) {
           return rewriter.notifyMatchFailure(
@@ -2805,9 +2792,28 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.customOpNameStringAttr(mode, "mode", "nearest") ||
             binder.customOpNameStringAttr(
                 coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
+            binder.s64IntegerAttr(antialias, "antialias", 0) ||
+            binder.s64IntegerAttr(exclude_outside, "exclude_outside", 0) ||
+            binder.f32FloatAttr(extrapolation_value, "extrapolation_value",
+                                0.0) ||
            binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
                                          "round_prefer_floor"))
          return failure();
+        if (antialias != 0) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented: support not present for antialias attribute");
+        }
+        if (exclude_outside != 0) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: support not present for "
+                         "exclude_outside attribute");
+        }
+        if (extrapolation_value != 0.0) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: support not present for "
+                         "extrapolation_value attribute");
+        }
         if (coordTfMode == "tf_crop_and_resize")
           return rewriter.notifyMatchFailure(
               binder.op, "unimplemented: coordinate transformation mode: "