Skip to content

Commit

Permalink
Merge pull request #350 from Xilinx/bump_to_8d237190
Browse files Browse the repository at this point in the history
[AutoBump] Merge with fixes of 8d23719 (Jun 27) (86)
  • Loading branch information
mgehre-amd authored Sep 16, 2024
2 parents 658586f + ee6e01f commit b46c5b7
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 56 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
// Takes the parameters for a clamp and turns it into a series of ops for
// integer inputs.
Value clampIntHelper(Location loc, Value arg, Value min, Value max,
OpBuilder &rewriter, bool isUnsigned = false);
OpBuilder &rewriter, bool isUnsigned);

// Determines whether the integer value falls within the range of integer type.
bool validIntegerRange(IntegerType ty, int64_t value);
Expand Down
72 changes: 40 additions & 32 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
Value max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
intermediateType);
auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
auto clamp =
clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);

// Truncate to the final value.
return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
Expand Down Expand Up @@ -402,24 +403,26 @@ static Value createLinalgBodyCalculationForElementwiseOp(
int64_t max =
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();

int64_t minRepresentable = std::numeric_limits<int64_t>::min();
int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
if (intTy.isUnsignedInteger()) {
if (intTy.getIntOrFloatBitWidth() > 63) {
(void)rewriter.notifyMatchFailure(
op, "support for 64-bit or larger integers is not implemented");
return {};
minRepresentable = 0;
if (intTy.getIntOrFloatBitWidth() <= 63) {
maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
.getZExtValue();
}
min = std::max(min, (int64_t)0);
max = std::min(max,
(int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
.getZExtValue());
} else {
min =
std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
.getSExtValue());
max =
std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
.getSExtValue());
} else if(intTy.getIntOrFloatBitWidth() <= 64) {
// Ensure that min & max fit into signed n-bit constants.
minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
.getSExtValue();
maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
.getSExtValue();
}
// Ensure that the bounds are representable as n-bit signed/unsigned integers.
min = std::max(min, minRepresentable);
max = std::max(max, minRepresentable);
min = std::min(min, maxRepresentable);
max = std::min(max, maxRepresentable);

auto minVal = rewriter.create<arith::ConstantIntOp>(
loc, min, intTy.getIntOrFloatBitWidth());
Expand Down Expand Up @@ -666,10 +669,8 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
}

static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
Location loc, Operation *operation,
ValueRange operands) {
auto rank =
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
Location loc, ValueRange operands,
int64_t rank) {
return llvm::map_to_vector(operands, [&](Value operand) {
return expandRank(rewriter, loc, operand, rank);
});
Expand Down Expand Up @@ -898,10 +899,10 @@ static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
Operation *operation, ValueRange operands,
ArrayRef<OpFoldResult> targetShape,
const TypeConverter *converter) {
const TypeConverter &converter) {
// Generate output tensor
auto resultType = cast_or_null<RankedTensorType>(converter->convertType(
cast<RankedTensorType>(operation->getResultTypes().front())));
auto resultType = cast_or_null<RankedTensorType>(
converter.convertType(operation->getResultTypes().front()));
if (!resultType) {
return rewriter.notifyMatchFailure(operation, "failed to convert type");
}
Expand Down Expand Up @@ -953,7 +954,7 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
ConversionPatternRewriter &rewriter,
const TypeConverter *converter) {
const TypeConverter &converter) {

// Collect op properties
assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
Expand All @@ -966,7 +967,9 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
// Lower operation
IndexPool indexPool;
auto loc = operation->getLoc();
auto expandedOperands = expandInputRanks(rewriter, loc, operation, operands);
auto rank =
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
auto [targetShape, masterOperands] =
computeTargetShape(rewriter, loc, indexPool, expandedOperands);
auto broadcastOperands = broadcastDynamicDimensions(
Expand Down Expand Up @@ -1173,8 +1176,8 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
LogicalResult
matchAndRewrite(SrcOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const final {
return elementwiseMatchAndRewriteHelper(op, operands.getOperands(),
rewriter, this->getTypeConverter());
return elementwiseMatchAndRewriteHelper(
op, operands.getOperands(), rewriter, *this->getTypeConverter());
}
};

Expand Down Expand Up @@ -1398,7 +1401,7 @@ class RescaleConverter : public OpConversionPattern<tosa::RescaleOp> {
loc, nestedBuilder.getI32IntegerAttr(intMax));

value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
nestedBuilder);
nestedBuilder, /*isUnsigned=*/false);

if (outIntType.getWidth() < 32) {
value = nestedBuilder.create<arith::TruncIOp>(
Expand Down Expand Up @@ -1772,7 +1775,7 @@ class GenericResizeConverter : public OpConversionPattern<tosa::ResizeOp> {

auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
val = b.create<arith::AddIOp>(val, offset);
val = clampIntHelper(loc, val, zeroI32, max, b);
val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
return b.create<arith::IndexCastOp>(b.getIndexType(), val);
};

Expand All @@ -1793,8 +1796,10 @@ class GenericResizeConverter : public OpConversionPattern<tosa::ResizeOp> {
Value max, ImplicitLocOpBuilder &b) {
val0 = in;
val1 = b.create<arith::AddIOp>(val0, oneVal);
val0 = clampIntHelper(loc, val0, zeroI32, max, b);
val1 = clampIntHelper(loc, val1, zeroI32, max, b);
val0 =
clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
val1 =
clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
};
Expand Down Expand Up @@ -2760,7 +2765,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
PointwiseConverter<tosa::CeilOp>,
PointwiseConverter<tosa::FloorOp>,
PointwiseConverter<tosa::ClampOp>,
PointwiseConverter<tosa::SigmoidOp>,
PointwiseConverter<tosa::SigmoidOp>
>(converter, patterns->getContext());

patterns->add<
IdentityNConverter<tosa::IdentityOp>,
ReduceConverter<tosa::ReduceAllOp>,
ReduceConverter<tosa::ReduceAnyOp>,
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
accETy);
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
/*isUnsigned=*/false);

poolVal = clamp;
// Convert type.
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
}

void runOnOperation() override {
TypeConverter converter;
mlir::tosa::populateTosaToLinalgTypeConversion(converter);

RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
Expand All @@ -64,13 +61,16 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
target.addLegalOp<tosa::SliceOp>();
target.addLegalOp<tosa::ReshapeOp>();
target.addLegalOp<tosa::PadOp>();
TypeConverter converter;
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalDialect<func::FuncDialect>(
[&](Operation *op) { return converter.isLegal(op); });
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

tosa::populateTosaTypeConversion(converter);

FunctionOpInterface func = getOperation();
mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
Expand Down
9 changes: 0 additions & 9 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,6 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %

// -----

// CHECK-LABEL: @clamp_on_large_int
func.func @clamp_on_large_int(%arg0: tensor<1xui64>) -> tensor<1xui64> {
// expected-error@+1 {{failed to legalize operation 'tosa.clamp'}}
%0 = tosa.clamp %arg0 {min_int = -1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
return %0 : tensor<1xui64>
}

// -----

// CHECK-LABEL: @rfft2d_with_non_float_type
func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
// expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}
Expand Down
40 changes: 30 additions & 10 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ func.func @test_simple_ui8(%arg0: tensor<1xui8>) -> () {
// -----

// CHECK-LABEL: @test_simple_i32
func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {
func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %unsigned64: tensor<1xui64>) -> () {
// CHECK: linalg.generic
// CHECK: arith.addi
%0 = tosa.add %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
Expand All @@ -674,7 +674,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {
%40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: arith.divui
%u4 = tosa.int_div %arg1, %arg1 : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>
%u4 = tosa.int_div %unsigned, %unsigned : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>

// CHECK: linalg.generic
// CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
Expand Down Expand Up @@ -708,7 +708,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {

// CHECK: linalg.generic
// CHECK: arith.shrui
%u11 = tosa.arithmetic_right_shift %arg1, %arg1 {round = 0 : i1} : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>
%u11 = tosa.arithmetic_right_shift %unsigned, %unsigned {round = 0 : i1} : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>

// CHECK: linalg.generic
// CHECK: arith.constant 1
Expand Down Expand Up @@ -736,7 +736,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {
// CHECK: and
// CHECK: arith.extui
// CHECK: arith.addi
%u12 = tosa.arithmetic_right_shift %arg1, %arg1 {round = 1 : i1} : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>
%u12 = tosa.arithmetic_right_shift %unsigned, %unsigned {round = 1 : i1} : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>

// CHECK: math.ctlz
%13 = tosa.clz %arg0 : (tensor<1xi32>) -> tensor<1xi32>
Expand Down Expand Up @@ -767,12 +767,32 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {
%19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: bb0(%[[IN:.*]]: i32,
// CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
// CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
// CHECK-DAG: arith.maxui %[[LB]],
// CHECK-DAG: arith.minui %[[UB]],
%u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>

// CHECK: linalg.generic
// CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
// CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
// CHECK-DAG: arith.maxui %[[LB]],
// CHECK-DAG: arith.minui %[[UB]],
%u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>

// CHECK: linalg.generic
// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[UB:.*]] = arith.constant 5 : i32
// CHECK-DAG: %[[MAX:.*]] = arith.maxui %[[LB]], %[[IN]]
// CHECK-DAG: arith.minui %[[UB]], %[[MAX]]
%u19 = tosa.clamp %arg1 {min_int = -1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
// CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
// CHECK-DAG: arith.maxui %[[LB]],
// CHECK-DAG: arith.minui %[[UB]],
%u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>

// CHECK: linalg.generic
// CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
// CHECK-DAG: arith.maxui %[[LB]],
// CHECK-DAG: arith.minui %[[UB]],
%u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>

// CHECK: linalg.generic
// CHECK: arith.trunci
Expand All @@ -793,7 +813,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {

// CHECK: linalg.generic
// CHECK: arith.uitofp
%u23 = tosa.cast %arg1 : (tensor<1xui32>) -> tensor<1xf32>
%u23 = tosa.cast %unsigned : (tensor<1xui32>) -> tensor<1xf32>

// CHECK: linalg.generic
// CHECK: arith.constant 0
Expand Down

0 comments on commit b46c5b7

Please sign in to comment.