[AutoBump] Merge with 41d04a8 (Jun 12)

mgehre-amd committed Sep 9, 2024
2 parents 077a2ee + 41d04a8 commit 23b2b30
Showing 27 changed files with 750 additions and 150 deletions.
13 changes: 13 additions & 0 deletions include/torch-mlir-c/TorchTypes.h
@@ -220,6 +220,19 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
 /// Gets the !torch.quint8 typeid.
 MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void);
 
+//===----------------------------------------------------------------------===//
+// torch.qint16 type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.qint16 type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt16(MlirType t);
+
+/// Gets the !torch.qint16 type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt16TypeGet(MlirContext context);
+
+/// Gets the !torch.qint16 typeid.
+MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt16TypeGetTypeID(void);
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//
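
As a quick illustration (not part of this commit), the new entry points can be smoke-tested through the MLIR C API; the round-trip through mlirTypeGetTypeID is an assumed usage pattern, not code from this change:

// Hypothetical usage sketch: assumes the torch-mlir dialect has been
// registered with the context before the type is constructed.
#include "mlir-c/IR.h"
#include "torch-mlir-c/TorchTypes.h"
#include <assert.h>

static void smokeTestQInt16(MlirContext ctx) {
  MlirType t = torchMlirTorchQInt16TypeGet(ctx);
  assert(torchMlirTypeIsATorchQInt16(t));
  assert(mlirTypeIDEqual(mlirTypeGetTypeID(t),
                         torchMlirTorchQInt16TypeGetTypeID()));
}
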
12 changes: 12 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -110,6 +110,18 @@ struct OpBinder {
     return success();
   }
 
+  ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) {
+    if (idx >= op->getNumOperands())
+      return failure();
+    valueIdx = op->getOperand(idx);
+    auto tt = dyn_cast<Torch::ListType>(valueIdx.getType());
+    if (!tt)
+      return failure();
+    if (!toValidTensorType(tt.getContainedType()))
+      return failure();
+    return success();
+  }
+
   ParseResult tensorListResultType(Torch::ListType &type0) {
     if (op->getNumResults() != 1)
       return failure();
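
A hedged usage sketch (the op name is hypothetical, not from this commit): inside a TorchOnnxToTorch pattern, the new helper binds a tensor-list operand and fails the match cleanly when the operand is not a list of tensors.

// Illustrative only: ParseResult converts to true on failure, so the
// binder call short-circuits the rewrite when binding does not succeed.
patterns.onOp(
    "SomeSequenceOp", 11,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) -> LogicalResult {
      Value tensorList;
      if (binder.tensorListOperandAtIndex(tensorList, 0))
        return rewriter.notifyMatchFailure(
            binder.op, "expected tensor-list operand at index 0");
      // ... lower the op using `tensorList` ...
      return success();
    });
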
2 changes: 1 addition & 1 deletion include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -36,7 +36,7 @@ Value createConstantIntList(OpBinder binder,
                             ConversionPatternRewriter &rewriter,
                             SmallVector<int64_t> cstInput);
 
-Type getQTorchTypeFromTorchIntType(Type ty);
+Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);
 
 template <typename T>
 Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
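
The updated body lives in Utils.cpp, which is not visible in this excerpt. Judging from the call sites later in this diff, a plausible shape for the helper is the following sketch (an assumption, not the committed code):

// Sketch: map a value tensor with a builtin integer dtype to the same-shaped
// tensor whose dtype is the corresponding quantized torch type, returning a
// null type when no quantized counterpart exists.
Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty) {
  auto tensorTy = dyn_cast<Torch::ValueTensorType>(ty);
  if (!tensorTy)
    return nullptr;
  MLIRContext *ctx = ty.getContext();
  Type dtype = tensorTy.getDtype();
  if (dtype.isUnsignedInteger(8))
    dtype = Torch::QUInt8Type::get(ctx);
  else if (dtype.isSignedInteger(8))
    dtype = Torch::QInt8Type::get(ctx);
  else if (dtype.isSignedInteger(16))
    dtype = Torch::QInt16Type::get(ctx);
  else if (dtype.isSignedInteger(32))
    dtype = Torch::QInt32Type::get(ctx);
  else
    return nullptr;
  return Torch::ValueTensorType::get(ctx, tensorTy.getOptionalSizes(), dtype);
}
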
10 changes: 10 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -315,6 +315,16 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
   }];
 }
 
+def Torch_QInt16Type : Torch_Type<"QInt16", "qint16"> {
+  let summary = "Type modeling `ScalarType::QInt16`, which doesn't yet exist";
+  let description = [{
+    PyTorch does not have 16-bit integer quantization support.
+
+    This torch type is added to provide a target for 16-bit quantization
+    schemes coming from imported ONNX models.
+  }];
+}
+
 def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> {
   let summary = "Type modeling `ScalarType::QUInt8`";
   let description = [{
46 changes: 28 additions & 18 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -86,24 +86,34 @@ enum class TypeKind {
 // at:: and c10:: parts of the macro are never used within the compiler -- we
 // only use this for the enum values.
 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_)  \
-  _(uint8_t, Byte) /* 0 */ \
-  _(int8_t, Char) /* 1 */ \
-  _(int16_t, Short) /* 2 */ \
-  _(int, Int) /* 3 */ \
-  _(int64_t, Long) /* 4 */ \
-  _(at::Half, Half) /* 5 */ \
-  _(float, Float) /* 6 */ \
-  _(double, Double) /* 7 */ \
-  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
-  _(c10::complex<float>, ComplexFloat) /* 9 */ \
-  _(c10::complex<double>, ComplexDouble) /* 10 */ \
-  _(bool, Bool) /* 11 */ \
-  _(c10::qint8, QInt8) /* 12 */ \
-  _(c10::quint8, QUInt8) /* 13 */ \
-  _(c10::qint32, QInt32) /* 14 */ \
-  _(at::BFloat16, BFloat16) /* 15 */ \
-  _(c10::quint4x2, QUInt4x2) /* 16 */ \
-  _(c10::quint2x4, QUInt2x4) /* 17 */
+  _(uint8_t, Byte) /* 0 */ \
+  _(int8_t, Char) /* 1 */ \
+  _(int16_t, Short) /* 2 */ \
+  _(int, Int) /* 3 */ \
+  _(int64_t, Long) /* 4 */ \
+  _(at::Half, Half) /* 5 */ \
+  _(float, Float) /* 6 */ \
+  _(double, Double) /* 7 */ \
+  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
+  _(c10::complex<float>, ComplexFloat) /* 9 */ \
+  _(c10::complex<double>, ComplexDouble) /* 10 */ \
+  _(bool, Bool) /* 11 */ \
+  _(c10::qint8, QInt8) /* 12 */ \
+  _(c10::quint8, QUInt8) /* 13 */ \
+  _(c10::qint32, QInt32) /* 14 */ \
+  _(at::BFloat16, BFloat16) /* 15 */ \
+  _(c10::quint4x2, QUInt4x2) /* 16 */ \
+  _(c10::quint2x4, QUInt2x4) /* 17 */ \
+  _(c10::bits1x8, Bits1x8) /* 18 */ \
+  _(c10::bits2x4, Bits2x4) /* 19 */ \
+  _(c10::bits4x2, Bits4x2) /* 20 */ \
+  _(c10::bits8, Bits8) /* 21 */ \
+  _(c10::bits16, Bits16) /* 22 */ \
+  _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
+  _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
+  _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
+  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
+  _(c10::qint16, QInt16) /* 27 */
 
 enum class ScalarType : int8_t {
 #define DEFINE_ENUM(_1, n) n,
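
For readers unfamiliar with the X-macro pattern used above: DEFINE_ENUM keeps only each entry's second argument, so after preprocessing the enum body is equivalent to the following sketch (middle entries elided), giving QInt16 the value 27 exactly as the comment column indicates.

// Equivalent expansion (sketch); enumerators take values in declaration order.
enum class ScalarType : int8_t {
  Byte,   // 0
  Char,   // 1
  // ... entries 2 through 26 elided ...
  QInt16, // 27
};
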
16 changes: 16 additions & 0 deletions lib/CAPI/TorchTypes.cpp
@@ -269,6 +269,22 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
   return wrap(Torch::QUInt8Type::getTypeID());
 }
 
+//===----------------------------------------------------------------------===//
+// torch.qint16 type.
+//===----------------------------------------------------------------------===//
+
+bool torchMlirTypeIsATorchQInt16(MlirType t) {
+  return isa<Torch::QInt16Type>(unwrap(t));
+}
+
+MlirType torchMlirTorchQInt16TypeGet(MlirContext context) {
+  return wrap(Torch::QInt16Type::get(unwrap(context)));
+}
+
+MlirTypeID torchMlirTorchQInt16TypeGetTypeID() {
+  return wrap(Torch::QInt16Type::getTypeID());
+}
+
 //===----------------------------------------------------------------------===//
 // torch.tensor type.
 //===----------------------------------------------------------------------===//
33 changes: 12 additions & 21 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -737,7 +737,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                   std::numeric_limits<float>::lowest()))
             return failure();
           auto minSplatAttr = SplatElementsAttr::get(
-              resultType.toBuiltinTensor().clone(resultDtype),
+              resultType.toBuiltinTensor(),
               rewriter.getFloatAttr(resultDtype, minValue));
           min = rewriter.create<Torch::ValueTensorLiteralOp>(
               binder.getLoc(), resultType, minSplatAttr);
@@ -748,7 +748,7 @@
                   std::numeric_limits<float>::max()))
             return failure();
           auto maxSplatAttr = SplatElementsAttr::get(
-              resultType.toBuiltinTensor().clone(resultDtype),
+              resultType.toBuiltinTensor(),
               rewriter.getFloatAttr(resultDtype, maxValue));
           max = rewriter.create<Torch::ValueTensorLiteralOp>(
               binder.getLoc(), resultType, maxSplatAttr);
@@ -829,7 +829,7 @@
         return success();
       });
   patterns.onOp(
-      "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+      "Concat", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
         SmallVector<Value> tensors;
         int64_t dim;
@@ -861,7 +861,7 @@
         if (binder.op->hasAttr("torch.onnx.value_float") &&
             !binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
           auto splatAttr =
-              SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
+              SplatElementsAttr::get(resultType.toBuiltinTensor(),
                                      rewriter.getFloatAttr(dtype, floatValue));
           rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
               binder.op, resultType, splatAttr);
@@ -872,7 +872,7 @@
         if (binder.op->hasAttr("torch.onnx.value_int") &&
             !binder.s64IntegerAttr(intValue, "value_int", 0)) {
           auto splatAttr =
-              SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
+              SplatElementsAttr::get(resultType.toBuiltinTensor(),
                                      rewriter.getIntegerAttr(dtype, intValue));
           rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
               binder.op, resultType, splatAttr);
@@ -932,8 +932,8 @@
         for (auto intVal : intValues) {
           apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
         }
-        auto attr = DenseElementsAttr::get(
-            resultType.toBuiltinTensor().clone(dtype), apValues);
+        auto attr =
+            DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues);
         rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
             binder.op, resultType, attr);
         return success();
@@ -1715,21 +1715,12 @@
                                              "requires known result dtype");
         if (scaleTy.getSizes().size() == 0 ||
             (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
-          Type qTy = operandTy.getDtype();
-
-          if (qTy.isUnsignedInteger(8)) {
-            qTy = rewriter.getType<Torch::QUInt8Type>();
-          } else if (qTy.isSignedInteger(8)) {
-            qTy = rewriter.getType<Torch::QInt8Type>();
-          } else if (qTy.isSignedInteger(32)) {
-            qTy = rewriter.getType<Torch::QInt32Type>();
-          } else {
+          auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
+          if (!qTensorTy) {
             return rewriter.notifyMatchFailure(binder.op,
                                                "unsupported result dtype");
           }
 
-          auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
-              resultType.getOptionalSizes(), qTy);
           scale = rewriter.create<Torch::AtenItemOp>(
               binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
           zeropoint = rewriter.create<Torch::AtenItemOp>(
@@ -2272,9 +2263,9 @@
         // Extract the fill value and dtype
         // ONNX requires value attr to be a tensor
         if (!attr) {
-          attr = DenseElementsAttr::get(
-              resultType.toBuiltinTensor().clone(resultDType),
-              rewriter.getFloatAttr(resultDType, 0.0));
+          attr =
+              DenseElementsAttr::get(resultType.toBuiltinTensor(),
+                                     rewriter.getFloatAttr(resultDType, 0.0));
         }
 
         // If its a dense resource attr we need to convert to a dense type:
17 changes: 4 additions & 13 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -408,20 +408,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getF64FloatAttr(1.0));
 
-        auto q = [&](Type qty) -> Type {
-          if (qty.isSignedInteger(8))
-            return rewriter.getType<Torch::QInt8Type>();
-          if (qty.isUnsignedInteger(8))
-            return rewriter.getType<Torch::QUInt8Type>();
-          if (qty.isSignedInteger(32))
-            return rewriter.getType<Torch::QInt32Type>();
-          return {};
-        };
+        auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy);
+        auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy);
 
-        Type lhsQTy = rewriter.getType<Torch::ValueTensorType>(
-            lhsTy.getOptionalSizes(), q(lhsTy.getDtype()));
-        Type rhsQTy = rewriter.getType<Torch::ValueTensorType>(
-            rhsTy.getOptionalSizes(), q(rhsTy.getDtype()));
+        if (!lhsQTy || !rhsQTy)
+          return rewriter.notifyMatchFailure(binder.op, "failed to get qtype");
 
         lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
             binder.getLoc(), lhsQTy, lhs, scale, lhsZp);
(19 more changed files not shown.)
