Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump externals/llvm-project from c9c2863 to 14e4586 #426

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 261 additions & 61 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,119 @@ class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
}
};

// Function to perform division with trunc rounding mode (rounding result
// towards zero) for float type inputs.
// This function takes in the division result between lhs and rhs rather
// than takes in the original lhs and rhs tensors as parameters.
Value truncFloatDivWithDivResult(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value divResult) {
// To implement trunc mode for float inputs, multiply the floored abs
// of the tensor with the elementwise signedness of the tensor.
// div_result = lhs / rhs
// trunc_val = floor(abs(div_result)) * sign(div_result)
auto zero =
tosa::getConstTensor<float>(rewriter, op, 0, {}, outType.getElementType())
.value();

auto one =
tosa::getConstTensor<float>(rewriter, op, 1, {}, outType.getElementType())
.value();

auto minusOne = tosa::getConstTensor<float>(rewriter, op, -1, {},
outType.getElementType())
.value();

auto cond = rewriter.create<tosa::GreaterEqualOp>(
op->getLoc(),
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)),
divResult, zero);

auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), outType, cond,
one, minusOne);

auto absDivResult =
rewriter.create<tosa::AbsOp>(op->getLoc(), outType, divResult);

auto flooredAbsDivResult =
rewriter.create<tosa::FloorOp>(op->getLoc(), outType, absDivResult);

Value result =
tosa::createMulOpAndCast(rewriter, op, outType, flooredAbsDivResult,
selectOp, /*shift=*/0)
.getResult();

return result;
}

// Function to perform division with trunc rounding mode (rounding result
// towards zero) for float type inputs
Value truncFloatDiv(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs) {
rhs = tosa::promoteType(rewriter, rhs, outType);

auto rhsRcp =
rewriter.create<tosa::ReciprocalOp>(op->getLoc(), rhs.getType(), rhs);

auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsRcp,
/*shift=*/0);

return truncFloatDivWithDivResult(rewriter, op, outType, divResult);
}

// Function to perform division with floor rounding mode (rounding result
// down) for integer type inputs.
Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType,
Value lhs, Value rhs) {
// To implement floor mode int input, utilize tosa::IntDivOp (trunc div
// result) with the following formula elementwise:
// floor_val = trunc_val - ((trunc_val * rhs != lhs)
// && (sign(lhs) != sign(rhs)))

// TOSA IntDiv requires inputs to be i32
auto i32Type =
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32));
lhs = tosa::promoteType(rewriter, lhs, i32Type);
rhs = tosa::promoteType(rewriter, rhs, i32Type);

auto intDivOp =
rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type, lhs, rhs);

auto zero = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();

auto one = tosa::getConstTensor<int32_t>(rewriter, op, 1, {}).value();

auto boolType =
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1));

auto lhsMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs,
/*shift=*/0);

auto lhsRhsDifferentSign =
rewriter.create<tosa::GreaterOp>(op->getLoc(), boolType, zero, lhsMulRhs);

auto truncMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type,
intDivOp, rhs, /*shift=*/0);

auto truncMulRhsEqualLhs =
rewriter.create<tosa::EqualOp>(op->getLoc(), boolType, truncMulRhs, lhs);

auto truncMulRhsNotEqualLhs = rewriter.create<tosa::LogicalNotOp>(
op->getLoc(), boolType, truncMulRhsEqualLhs);

auto truncMinusOne =
rewriter.create<tosa::SubOp>(op->getLoc(), i32Type, intDivOp, one);

auto cond = rewriter.create<tosa::LogicalAndOp>(
op->getLoc(), boolType, lhsRhsDifferentSign, truncMulRhsNotEqualLhs);

auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), i32Type, cond,
truncMinusOne, intDivOp);

Value result = tosa::promoteType(rewriter, selectOp, outType);

return result;
}

template <typename AtenOpT>
class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
public:
Expand Down Expand Up @@ -502,25 +615,64 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));

// auto result;
// Get rounding mode for aten.div.Tensor_mode
std::string roundMode;
if constexpr (std::is_same<AtenOpT, AtenDivTensorModeOp>() ||
std::is_same<AtenOpT, AtenDivScalarModeOp>()) {
if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundMode)))
return rewriter.notifyMatchFailure(
op, "Non-const rounding mode parameter unsupported");
}

Value result;
if (isa<mlir::FloatType>(outType.getElementType())) {
// The input to the reciprocal is an integer sometimes, and we may need to
// promote it to a floating point. Per TOSA specification, the input types
// can only be floating point for tosa::ReciprocalOp.
Value rhsCasted = tosa::promoteType(rewriter, rhsTensor, outType);
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), rhsCasted.getType(), rhsCasted);

result = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
rcpOp.getResult(), /*shift=*/0)
.getResult();
// The input to the reciprocal is an integer sometimes, and we may need
// to promote it to a floating point. Per TOSA specification, the input
// types can only be floating point for tosa::ReciprocalOp.
rhsTensor = tosa::promoteType(rewriter, rhsTensor, outType);
auto rhsRcp = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), rhsTensor.getType(), rhsTensor);

auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
rhsRcp, /*shift=*/0);

// Round result based on rounding mode
if (roundMode.compare("floor") == 0) {
// "floor": rounds the results of the division down. Equivalent to
// floor division in Python (the // operator).
auto floorOp =
rewriter.create<tosa::FloorOp>(op->getLoc(), outType, divResult);

result = floorOp.getResult();
} else if (roundMode.compare("trunc") == 0) {
// "trunc": rounds the results of the division towards zero. Equivalent
// to C-style integer division.
result = truncFloatDivWithDivResult(rewriter, op, outType, divResult);
} else {
// None: No rounding mode
result = divResult.getResult();
}
} else {
// The output type can be different than the input types (e.g. dividing an
// int tensor results in a floating point tensor).
result = tosa::createBinaryOpAndCast<tosa::IntDivOp>(
rewriter, op, outType, lhs, rhsTensor)
.getResult();
if (roundMode.compare("floor") == 0) {
// "floor": rounds the results of the division down. Equivalent to floor
// division in Python (the // operator).
result = floorIntDiv(rewriter, op, outType, lhs, rhsTensor);
} else {
// "trunc": rounds the results of the division towards zero. Equivalent
// to C-style integer division.
// None: no rounding mode.

// TOSA IntDiv requires inputs to be i32
auto i32Type = RankedTensorType::get(outType.getShape(),
rewriter.getIntegerType(32));
lhs = tosa::promoteType(rewriter, lhs, i32Type);
rhsTensor = tosa::promoteType(rewriter, rhsTensor, i32Type);

auto intDivOp = rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type,
lhs, rhsTensor);

result = tosa::promoteType(rewriter, intDivOp, outType);
}
}

rewriter.replaceOp(op, {result});
Expand Down Expand Up @@ -5092,56 +5244,94 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
AtenRemainderScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
template <typename AtenOpT>
class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Value self = adaptor.getSelf();
auto selfTy = cast<RankedTensorType>(self.getType());
Value self = adaptor.getSelf();
auto selfTy = cast<RankedTensorType>(self.getType());

if (!selfTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Remainder");
if (!selfTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Remainder/Fmod");

auto outType =
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
auto outType =
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));

Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");

Value otherTensor;
Value other = op.getOther();
if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
outElemTy, {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA Remainder operation");

if (selfTy.getElementType() != outElemTy)
self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);

auto divTensor = self;
if (isa<mlir::FloatType>(outElemTy)) {
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
op.getLoc(), otherTensor.getType(), otherTensor);
divTensor = rewriter.create<tosa::MulOp>(
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
divTensor = rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
} else {
divTensor = rewriter.create<tosa::IntDivOp>(op.getLoc(), outType, self,
otherTensor);
}
Value otherTensor;
if constexpr (std::is_same<AtenOpT, AtenRemainderScalarOp>()) {
Value other = op.getOther();
if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
outElemTy, {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA Remainder/Fmod operation");
} else {
otherTensor = adaptor.getOther();
auto otherTy = cast<RankedTensorType>(otherTensor.getType());

auto mulTensor =
rewriter.create<tosa::MulOp>(op.getLoc(), outType, otherTensor, divTensor,
/*shift=*/0);
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
if (!otherTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Remainder/Fmod");
}

return success();
}
constexpr bool isRemainderOp =
std::is_same<AtenOpT, AtenRemainderScalarOp>() ||
std::is_same<AtenOpT, AtenRemainderTensorOp>() ||
std::is_same<AtenOpT, AtenRemainderIntOp>();

if (selfTy.getElementType() != outElemTy)
self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);

Value divTensor;
if (isRemainderOp) {
// torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
if (isa<mlir::FloatType>(outElemTy)) {
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
op.getLoc(), otherTensor.getType(), otherTensor);
divTensor = rewriter.create<tosa::MulOp>(
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
divTensor =
rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
} else {
divTensor = floorIntDiv(rewriter, op, outType, self, otherTensor);
}
} else {
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
if (isa<mlir::FloatType>(outElemTy)) {
divTensor = truncFloatDiv(rewriter, op, outType, self, otherTensor);
} else {
// TOSA IntDiv requires inputs to be i32
auto i32Type = RankedTensorType::get(outType.getShape(),
rewriter.getIntegerType(32));
self = tosa::promoteType(rewriter, self, i32Type);
otherTensor = tosa::promoteType(rewriter, otherTensor, i32Type);

auto intDivTensor = rewriter.create<tosa::IntDivOp>(
op->getLoc(), i32Type, self, otherTensor);

divTensor = tosa::promoteType(rewriter, intDivTensor, outType);
}
}

auto mulTensor = rewriter.create<tosa::MulOp>(op.getLoc(), outType,
otherTensor, divTensor,
/*shift=*/0);
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);

return success();
}
};

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
Expand Down Expand Up @@ -6546,11 +6736,11 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
Expand All @@ -6575,8 +6765,19 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp);
#undef INSERT_BINARY_DIV_PATTERN

#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp);
#undef INSERT_REMAINDER_FMOD_OP_PATTERN

#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
Expand Down Expand Up @@ -6732,7 +6933,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenCopyOp);
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenSqrtOp);
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
Expand Down
Loading
Loading