Skip to content

Commit

Permalink
Use IntegerRangeAnalysis to get bounds of allocation. (iree-org#18991)
Browse files Browse the repository at this point in the history
Towards iree-org#18973

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar authored Nov 1, 2024
1 parent 046a705 commit 3bb7fd2
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 7 deletions.
47 changes: 42 additions & 5 deletions compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -37,22 +40,47 @@ static Value skipAffineMaxZero(Value dim) {
return *affineMax.getSymbolOperands().begin();
}

static LogicalResult padAlloc(MLIRContext *context, memref::AllocOp allocOp) {
static FailureOr<int64_t> getUpperBound(Value dim,
const DataFlowSolver &solver) {
// Check the integer range analysis.
if (auto *maybeRange =
solver.lookupState<dataflow::IntegerValueRangeLattice>(dim)) {
IntegerValueRange range = maybeRange->getValue();
if (!range.isUninitialized() &&
range.getValue().smax() !=
IntegerValueRange::getMaxRange(dim).getValue().smax()) {
return range.getValue().smax().getSExtValue();
}
}

// Check the value bounds constraint set.
// TODO: These two analysis could be merged, but probably needs
// to happen usptream.
auto ub = ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, {dim, /*dim=*/std::nullopt},
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (succeeded(ub)) {
return ub.value();
}
return failure();
}

static LogicalResult padAlloc(MLIRContext *context, memref::AllocOp allocOp,
const DataFlowSolver &solver) {
IRRewriter rewriter(context);
rewriter.setInsertionPoint(allocOp);
SmallVector<int64_t> shape = llvm::to_vector(allocOp.getType().getShape());
SmallVector<OpFoldResult> sizes;
size_t dynamicDimIdx = 0;

for (int64_t &dimSize : shape) {
if (!ShapedType::isDynamic(dimSize)) {
sizes.push_back(rewriter.getIndexAttr(dimSize));
continue;
}
Value dim = allocOp.getDynamicSizes()[dynamicDimIdx++];
dim = skipAffineMaxZero(dim);
auto ub = ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, {dim, /*dim=*/std::nullopt},
/*stopCondition=*/nullptr, /*closedUB=*/true);
FailureOr<int64_t> ub = getUpperBound(dim, solver);
if (failed(ub)) {
return allocOp.emitOpError(
"unexpected allocation without upper bound shapes");
Expand Down Expand Up @@ -84,11 +112,20 @@ struct PadDynamicAllocPass final
auto funcOp = getOperation();
MLIRContext *context = &getContext();
SmallVector<memref::AllocOp> sharedMemAllocs;

DataFlowSolver solver;
solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::IntegerRangeAnalysis>();
if (failed(solver.initializeAndRun(funcOp))) {
funcOp.emitOpError("failed to run integer range analysis");
return signalPassFailure();
}

// Collect all the alloc operations.
funcOp.walk(
[&](memref::AllocOp allocOp) { sharedMemAllocs.push_back(allocOp); });
for (memref::AllocOp alloc : sharedMemAllocs) {
if (failed(padAlloc(context, alloc)))
if (failed(padAlloc(context, alloc, solver)))
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-pad-dynamic-alloc))" %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-pad-dynamic-alloc))" --split-input-file --mlir-print-local-scope %s | FileCheck %s

// CHECK-LABEL: dynamic_alloc
func.func @dynamic_alloc(%id : index) {
Expand All @@ -13,6 +13,8 @@ func.func @dynamic_alloc(%id : index) {
return
}

// -----

// CHECK-LABEL: dynamic_alloc_max_0
func.func @dynamic_alloc_max_0(%id : index) {
%c0 = arith.constant 0 : index
Expand All @@ -26,3 +28,13 @@ func.func @dynamic_alloc_max_0(%id : index) {
vector.store %cst, %0[%c0, %c0, %c0] : memref<1x?x32xf32, 3>, vector<4xf32>
return
}

// -----

func.func @dynamic_bound_alloc(%id : index) {
%0 = util.assume.int %id<umin = 0, umax = 4088> : index
%1 = memref.alloc(%0) : memref<?xf32, 3>
return
}
// CHECK-LABEL: func @dynamic_bound_alloc(
// CHECK: %alloc = memref.alloc() : memref<4088xf32, 3>
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
modulePassManager.addPass(createStripDebugInfoPass());
// Cast address spaces of all function arguments to generic.
modulePassManager.addPass(createLLVMGPUCastAddressSpaceFunctionPass());
modulePassManager.addPass(IREE::Util::createDropCompilerHintsPass());

if (forROCDL) {
// convert to ROCDL.
Expand Down Expand Up @@ -1200,7 +1201,6 @@ void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager,
.addPass(createLLVMGPULowerExecutableTargetPass);
}
variantPassManager.addPass(createReconcileTranslationInfoPass());
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass());

//===--------------------------------------------------------------------===//
// Convert Linalg ops to LLVM+NVVM/ROCDL ops.
Expand Down

0 comments on commit 3bb7fd2

Please sign in to comment.