Skip to content

Commit

Permalink
[Codegen] Support inferring scalable vectors for remainder loops (ire…
Browse files Browse the repository at this point in the history
…e-org#17998)

Follow on for iree-org#17891 that extends the scalable vector size inference to
work for remainder loops. The upper bound of a scalable remainder dim is
an expression of the form `(vscale * n) + cst` (cst <= 0).

This patch adds some simple pattern matching that rounds up the upper
bound by removing the `+ cst`, which allows matching the scalable
quantity.

---------

Signed-off-by: Benjamin Maxwell <[email protected]>
  • Loading branch information
MacDue authored Jul 29, 2024
1 parent b8370b8 commit 80c33b0
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ inferSizesFromIR(linalg::LinalgOp linalgOp, std::optional<OpResult> opResult) {
for (auto operandDimPair : operandDimPairs) {
Value operand = operandDimPair.first;
unsigned operandDim = operandDimPair.second;
maybeDimBound = computeDimUpperBound(operand, operandDim, vscaleRange);
maybeDimBound = computeDimUpperBound(operand, operandDim, vscaleRange,
RoundUpVscaleMultiple::Yes);
if (succeeded(maybeDimBound)) {
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,43 @@ func.func @dynamic_fill_with_scalable_tiling_infer_vector_size(%arg0: tensor<1x6
// CHECK-MASK: scf.for
// CHECK-MASK: scf.for
// CHECK-MASK: vector.transfer_write %[[CST]], {{.*}} {in_bounds = [true, true, true, true]} : vector<1x1x4x[4]xf32>, tensor<1x1x4x?xf32>

// -----

#aarch64_sve = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu_features = "+sve", target_triple = "aarch64-none-elf"}>

func.func @dynamic_fill_with_scalable_tiling_infer_remainder_vector_size(%arg0: tensor<1x67x120x100xf32>) -> tensor<1x67x120x100xf32>
attributes {hal.executable.target = #aarch64_sve}
{
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c100 = arith.constant 100 : index
%c67 = arith.constant 67 : index
%c120 = arith.constant 120 : index
%cst = arith.constant 0.000000e+00 : f32
%vscale = vector.vscale
%c4_vscale = arith.muli %vscale, %c4 : index
%0 = scf.for %arg1 = %c0 to %c67 step %c1 iter_args(%arg2 = %arg0) -> (tensor<1x67x120x100xf32>) {
%1 = scf.for %arg3 = %c0 to %c120 step %c4 iter_args(%arg4 = %arg2) -> (tensor<1x67x120x100xf32>) {
%rem_start = affine.apply affine_map<()[s0] -> (-(100 mod s0) + 100)>()[%c4_vscale]
%3 = scf.for %arg5 = %rem_start to %c100 step %c4_vscale iter_args(%arg6 = %arg4) -> (tensor<1x67x120x100xf32>) {
%rem_elts = affine.apply affine_map<(d0) -> (-d0 + 100)>(%arg5)
%extracted_slice = tensor.extract_slice %arg6[0, %arg1, %arg3, %arg5] [1, 1, 4, %rem_elts] [1, 1, 1, 1] : tensor<1x67x120x100xf32> to tensor<1x1x4x?xf32>
%4 = linalg.fill ins(%cst : f32) outs(%extracted_slice : tensor<1x1x4x?xf32>) -> tensor<1x1x4x?xf32>
%inserted_slice = tensor.insert_slice %4 into %arg6[0, %arg1, %arg3, %arg5] [1, 1, 4, %rem_elts] [1, 1, 1, 1] : tensor<1x1x4x?xf32> into tensor<1x67x120x100xf32>
scf.yield %inserted_slice : tensor<1x67x120x100xf32>
}
scf.yield %3 : tensor<1x67x120x100xf32>
}
scf.yield %1 : tensor<1x67x120x100xf32>
}
return %0 : tensor<1x67x120x100xf32>
}

// CHECK-MASK-LABEL: func.func @dynamic_fill_with_scalable_tiling_infer_remainder_vector_size
// CHECK-MASK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x4x[4]xf32>
// CHECK-MASK: scf.for
// CHECK-MASK: scf.for
// CHECK-MASK: scf.for
// CHECK-MASK: vector.transfer_write %[[CST]], {{.*}} {in_bounds = [true, true, true, true]} : vector<1x1x4x[4]xf32>, tensor<1x1x4x?xf32>
25 changes: 21 additions & 4 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1159,7 +1159,8 @@ getDefaultVscaleRange(IREE::HAL::ExecutableTargetAttr targetAttr) {

FailureOr<DimBoundSize>
computeDimUpperBound(Value shapedValue, unsigned dimNum,
std::optional<VscaleRange> vscaleRange) {
std::optional<VscaleRange> vscaleRange,
RoundUpVscaleMultiple roundUp) {
if (!vscaleRange.has_value()) {
FailureOr<int64_t> maybeDimBoundSize =
ValueBoundsConstraintSet::computeConstantBound(
Expand All @@ -1175,9 +1176,25 @@ computeDimUpperBound(Value shapedValue, unsigned dimNum,
shapedValue, dimNum,
/*vscaleMin=*/vscaleRange->min,
/*vscaleMax=*/vscaleRange->max, presburger::BoundType::UB);
if (succeeded(maybeDimBound))
return maybeDimBound->getSize();
return failure();
if (failed(maybeDimBound))
return failure();
auto boundSize = maybeDimBound->getSize();
if (succeeded(boundSize))
return boundSize;
if (roundUp == RoundUpVscaleMultiple::No)
return failure();
// If the upper bound map is of the form `add(subExpr, cst)` (cst <= 0),
// round it up to `subExpr` (and try matching the bound again).
auto binOp = dyn_cast<AffineBinaryOpExpr>(maybeDimBound->map.getResult(0));
if (!binOp || binOp.getKind() != AffineExprKind::Add)
return failure();
auto cst = dyn_cast<AffineConstantExpr>(binOp.getRHS());
if (!cst || cst.getValue() > 0)
return failure();
DimBound roundedDimBound{AffineMap::get(maybeDimBound->map.getNumDims(),
maybeDimBound->map.getNumSymbols(),
binOp.getLHS())};
return roundedDimBound.getSize();
}

} // namespace mlir::iree_compiler
7 changes: 6 additions & 1 deletion compiler/src/iree/compiler/Codegen/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,17 @@ getDefaultVscaleRange(IREE::HAL::ExecutableTargetAttr targetAttr);
using DimBound = vector::ConstantOrScalableBound;
using DimBoundSize = DimBound::BoundSize;

/// Should the scalable upper bound be rounded up to the nearest multiple of
/// vscale?
enum class RoundUpVscaleMultiple { No, Yes };

/// Computes the upper bound of `dimNum` dim of the ShapedType value
/// `shapedValue`. If the optional `vscaleRange` is provided then the computed
/// bound can be a scalable quantity.
FailureOr<DimBoundSize>
computeDimUpperBound(Value shapedValue, unsigned dimNum,
std::optional<VscaleRange> vscaleRange);
std::optional<VscaleRange> vscaleRange,
RoundUpVscaleMultiple = RoundUpVscaleMultiple::No);

} // namespace mlir::iree_compiler

Expand Down

0 comments on commit 80c33b0

Please sign in to comment.