[AutoBump] Merge with 41d04a89 (Jun 12) (64) #301

Merged · 13 commits · Sep 11, 2024
13 changes: 13 additions & 0 deletions include/torch-mlir-c/TorchTypes.h
@@ -220,6 +220,19 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
/// Gets the !torch.quint8 typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void);

//===----------------------------------------------------------------------===//
// torch.qint16 type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.qint16 type
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt16(MlirType t);

/// Gets the !torch.qint16 type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt16TypeGet(MlirContext context);

/// Gets the !torch.qint16 typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt16TypeGetTypeID(void);

//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//
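A minimal usage sketch of the new C API (not part of the diff; the helper name is illustrative, and an MlirContext with the torch dialect registered is assumed):

#include "torch-mlir-c/TorchTypes.h"
#include <assert.h>

// Create the !torch.qint16 type and round-trip it through the new checker.
static MlirType makeQInt16(MlirContext ctx) {
  MlirType t = torchMlirTorchQInt16TypeGet(ctx);
  assert(torchMlirTypeIsATorchQInt16(t) && "expected !torch.qint16");
  return t;
}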
12 changes: 12 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -110,6 +110,18 @@ struct OpBinder {
return success();
}

ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) {
if (idx >= op->getNumOperands())
return failure();
valueIdx = op->getOperand(idx);
auto tt = dyn_cast<Torch::ListType>(valueIdx.getType());
if (!tt)
return failure();
if (!toValidTensorType(tt.getContainedType()))
return failure();
return success();
}

ParseResult tensorListResultType(Torch::ListType &type0) {
if (op->getNumResults() != 1)
return failure();
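Illustrative use of the new binder method inside a lowering (sketch only; `binder` and `rewriter` as in the surrounding patterns, and the operand index is hypothetical):

// Bind operand #1, requiring a !torch.list whose contained type is a valid
// tensor type; like the other binder helpers, the ParseResult is truthy on
// failure.
Value sequence;
if (binder.tensorListOperandAtIndex(sequence, 1))
  return rewriter.notifyMatchFailure(
      binder.op, "expected tensor-list operand at index 1");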
2 changes: 1 addition & 1 deletion include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -36,7 +36,7 @@ Value createConstantIntList(OpBinder binder,
ConversionPatternRewriter &rewriter,
SmallVector<int64_t> cstInput);

-Type getQTorchTypeFromTorchIntType(Type ty);
+Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);

template <typename T>
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
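The helper's body is not shown in this diff; a plausible sketch of what the tightened signature implies, based on the call sites later in this PR (an assumption, not the committed implementation):

// Map the integer dtype of a !torch.vtensor to its quantized counterpart,
// preserving sizes; a null return signals an unsupported dtype, which call
// sites turn into a match failure.
Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty) {
  auto tensorTy = dyn_cast<Torch::ValueTensorType>(ty);
  if (!tensorTy)
    return {};
  MLIRContext *ctx = ty.getContext();
  Type dtype = tensorTy.getDtype();
  Type qDtype;
  if (dtype.isUnsignedInteger(8))
    qDtype = Torch::QUInt8Type::get(ctx);
  else if (dtype.isSignedInteger(8))
    qDtype = Torch::QInt8Type::get(ctx);
  else if (dtype.isSignedInteger(16))
    qDtype = Torch::QInt16Type::get(ctx);
  else if (dtype.isSignedInteger(32))
    qDtype = Torch::QInt32Type::get(ctx);
  else
    return {};
  return Torch::ValueTensorType::get(ctx, tensorTy.getOptionalSizes(), qDtype);
}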
10 changes: 10 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -315,6 +315,16 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
}];
}

def Torch_QInt16Type : Torch_Type<"QInt16", "qint16"> {
let summary = "Type modeling `ScalarType::QInt16`, which doesn't yet exist";
let description = [{
Pytorch does not have 16-bit integer quantization support.

This torch type is added to provide a target for 16-bit quantization
schemes coming from imported onnx models.
}];
}

def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
let summary = "Type modeling `ScalarType::QUInt8`";
let description = [{
46 changes: 28 additions & 18 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -86,24 +86,34 @@ enum class TypeKind {
// at:: and c10:: parts of the macro are never used within the compiler -- we
// only use this for the enum values.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
-  _(uint8_t, Byte) /* 0 */                                                    \
-  _(int8_t, Char) /* 1 */                                                     \
-  _(int16_t, Short) /* 2 */                                                   \
-  _(int, Int) /* 3 */                                                         \
-  _(int64_t, Long) /* 4 */                                                    \
-  _(at::Half, Half) /* 5 */                                                   \
-  _(float, Float) /* 6 */                                                     \
-  _(double, Double) /* 7 */                                                   \
-  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */                             \
-  _(c10::complex<float>, ComplexFloat) /* 9 */                                \
-  _(c10::complex<double>, ComplexDouble) /* 10 */                             \
-  _(bool, Bool) /* 11 */                                                      \
-  _(c10::qint8, QInt8) /* 12 */                                               \
-  _(c10::quint8, QUInt8) /* 13 */                                             \
-  _(c10::qint32, QInt32) /* 14 */                                             \
-  _(at::BFloat16, BFloat16) /* 15 */                                          \
-  _(c10::quint4x2, QUInt4x2) /* 16 */                                         \
-  _(c10::quint2x4, QUInt2x4) /* 17 */
+  _(uint8_t, Byte) /* 0 */                                                    \
+  _(int8_t, Char) /* 1 */                                                     \
+  _(int16_t, Short) /* 2 */                                                   \
+  _(int, Int) /* 3 */                                                         \
+  _(int64_t, Long) /* 4 */                                                    \
+  _(at::Half, Half) /* 5 */                                                   \
+  _(float, Float) /* 6 */                                                     \
+  _(double, Double) /* 7 */                                                   \
+  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */                             \
+  _(c10::complex<float>, ComplexFloat) /* 9 */                                \
+  _(c10::complex<double>, ComplexDouble) /* 10 */                             \
+  _(bool, Bool) /* 11 */                                                      \
+  _(c10::qint8, QInt8) /* 12 */                                               \
+  _(c10::quint8, QUInt8) /* 13 */                                             \
+  _(c10::qint32, QInt32) /* 14 */                                             \
+  _(at::BFloat16, BFloat16) /* 15 */                                          \
+  _(c10::quint4x2, QUInt4x2) /* 16 */                                         \
+  _(c10::quint2x4, QUInt2x4) /* 17 */                                         \
+  _(c10::bits1x8, Bits1x8) /* 18 */                                           \
+  _(c10::bits2x4, Bits2x4) /* 19 */                                           \
+  _(c10::bits4x2, Bits4x2) /* 20 */                                           \
+  _(c10::bits8, Bits8) /* 21 */                                               \
+  _(c10::bits16, Bits16) /* 22 */                                             \
+  _(c10::Float8_e5m2, Float8_e5m2) /* 23 */                                   \
+  _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */                               \
+  _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */                           \
+  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */                           \
+  _(c10::qint16, QInt16) /* 27 */

enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n,
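Because ScalarType is stamped out from this macro in order by the enum definition just above, the new entry's ordinal must match its comment; an illustrative compile-time check, assuming the enclosing torch_upstream namespace:

static_assert(static_cast<int>(torch_upstream::ScalarType::QInt16) == 27,
              "QInt16 must occupy the ordinal noted in the macro comment");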
16 changes: 16 additions & 0 deletions lib/CAPI/TorchTypes.cpp
@@ -269,6 +269,22 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
return wrap(Torch::QUInt8Type::getTypeID());
}

//===----------------------------------------------------------------------===//
// torch.qint16 type.
//===----------------------------------------------------------------------===//

bool torchMlirTypeIsATorchQInt16(MlirType t) {
return isa<Torch::QInt16Type>(unwrap(t));
}

MlirType torchMlirTorchQInt16TypeGet(MlirContext context) {
return wrap(Torch::QInt16Type::get(unwrap(context)));
}

MlirTypeID torchMlirTorchQInt16TypeGetTypeID() {
return wrap(Torch::QInt16Type::getTypeID());
}

//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//
33 changes: 12 additions & 21 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -737,7 +737,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
std::numeric_limits<float>::lowest()))
return failure();
auto minSplatAttr = SplatElementsAttr::get(
-            resultType.toBuiltinTensor().clone(resultDtype),
+            resultType.toBuiltinTensor(),
rewriter.getFloatAttr(resultDtype, minValue));
min = rewriter.create<Torch::ValueTensorLiteralOp>(
binder.getLoc(), resultType, minSplatAttr);
@@ -748,7 +748,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
std::numeric_limits<float>::max()))
return failure();
auto maxSplatAttr = SplatElementsAttr::get(
-            resultType.toBuiltinTensor().clone(resultDtype),
+            resultType.toBuiltinTensor(),
rewriter.getFloatAttr(resultDtype, maxValue));
max = rewriter.create<Torch::ValueTensorLiteralOp>(
binder.getLoc(), resultType, maxSplatAttr);
@@ -829,7 +829,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return success();
});
patterns.onOp(
"Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
"Concat", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
SmallVector<Value> tensors;
int64_t dim;
@@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
if (binder.op->hasAttr("torch.onnx.value_float") &&
!binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
auto splatAttr =
-              SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
+              SplatElementsAttr::get(resultType.toBuiltinTensor(),
rewriter.getFloatAttr(dtype, floatValue));
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, splatAttr);
@@ -872,7 +872,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
if (binder.op->hasAttr("torch.onnx.value_int") &&
!binder.s64IntegerAttr(intValue, "value_int", 0)) {
auto splatAttr =
-              SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
+              SplatElementsAttr::get(resultType.toBuiltinTensor(),
rewriter.getIntegerAttr(dtype, intValue));
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, splatAttr);
@@ -932,8 +932,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
for (auto intVal : intValues) {
apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
}
-        auto attr = DenseElementsAttr::get(
-            resultType.toBuiltinTensor().clone(dtype), apValues);
+        auto attr =
+            DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues);
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, attr);
return success();
@@ -1715,21 +1715,12 @@
"requires known result dtype");
if (scaleTy.getSizes().size() == 0 ||
(scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
-          Type qTy = operandTy.getDtype();
-
-          if (qTy.isUnsignedInteger(8)) {
-            qTy = rewriter.getType<Torch::QUInt8Type>();
-          } else if (qTy.isSignedInteger(8)) {
-            qTy = rewriter.getType<Torch::QInt8Type>();
-          } else if (qTy.isSignedInteger(32)) {
-            qTy = rewriter.getType<Torch::QInt32Type>();
-          } else {
+          auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
+          if (!qTensorTy) {
return rewriter.notifyMatchFailure(binder.op,
"unsupported result dtype");
}

-          auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
-              resultType.getOptionalSizes(), qTy);
scale = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
zeropoint = rewriter.create<Torch::AtenItemOp>(
@@ -2272,9 +2263,9 @@
// Extract the fill value and dtype
// ONNX requires value attr to be a tensor
if (!attr) {
-          attr = DenseElementsAttr::get(
-              resultType.toBuiltinTensor().clone(resultDType),
-              rewriter.getFloatAttr(resultDType, 0.0));
+          attr =
+              DenseElementsAttr::get(resultType.toBuiltinTensor(),
+                                     rewriter.getFloatAttr(resultDType, 0.0));
}

// If its a dense resource attr we need to convert to a dense type:
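The recurring simplification in this file drops .clone(<dtype>) when building literal attributes; the working assumption (tied to the LLVM bump this PR merges in) is that toBuiltinTensor() now yields a builtin tensor type whose element type already matches the vtensor dtype. Side by side, for the Clamp case above:

// Before: re-type the builtin tensor to the attribute's element type.
auto before = SplatElementsAttr::get(
    resultType.toBuiltinTensor().clone(resultDtype),
    rewriter.getFloatAttr(resultDtype, minValue));
// After: toBuiltinTensor() is assumed to carry resultDtype already.
auto after = SplatElementsAttr::get(
    resultType.toBuiltinTensor(),
    rewriter.getFloatAttr(resultDtype, minValue));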
17 changes: 4 additions & 13 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -408,20 +408,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(1.0));

-        auto q = [&](Type qty) -> Type {
-          if (qty.isSignedInteger(8))
-            return rewriter.getType<Torch::QInt8Type>();
-          if (qty.isUnsignedInteger(8))
-            return rewriter.getType<Torch::QUInt8Type>();
-          if (qty.isSignedInteger(32))
-            return rewriter.getType<Torch::QInt32Type>();
-          return {};
-        };
+        auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy);
+        auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy);

-        Type lhsQTy = rewriter.getType<Torch::ValueTensorType>(
-            lhsTy.getOptionalSizes(), q(lhsTy.getDtype()));
-        Type rhsQTy = rewriter.getType<Torch::ValueTensorType>(
-            rhsTy.getOptionalSizes(), q(rhsTy.getDtype()));
+        if (!lhsQTy || !rhsQTy)
+          return rewriter.notifyMatchFailure(binder.op, "failed to get qtype");

lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), lhsQTy, lhs, scale, lhsZp);
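With both conversion files now funneling through getQTorchTypeFromTorchIntType, the shared quantized-lowering idiom condenses to the sketch below (names as in the diff, control flow simplified; scale and zeropoint are assumed already extracted as scalars via AtenItemOp):

auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
if (!qTensorTy)
  return rewriter.notifyMatchFailure(binder.op, "unsupported operand dtype");
// Wrap the raw integer tensor as a per-tensor-quantized value...
Value quantized = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
    binder.getLoc(), qTensorTy, operand, scale, zeropoint);
// ...then dequantize into the floating-point result type.
rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(binder.op, resultType,
                                                         quantized);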