Skip to content

Commit

Permalink
normalize-memrefs: Normalize memref.alloca
Browse files Browse the repository at this point in the history
The pass was only handling memref.alloc,
and this extends it to also handle memref.alloca.
  • Loading branch information
mgehre-amd committed Oct 1, 2024
1 parent 9054950 commit 3db9abb
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 8 deletions.
8 changes: 7 additions & 1 deletion mlir/include/mlir/Dialect/Affine/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class FuncOp;

namespace memref {
class AllocOp;
class AllocaOp;
} // namespace memref

struct LogicalResult;
Expand Down Expand Up @@ -247,7 +248,12 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
/// Rewrites the memref defined by this alloc op to have an identity layout map
/// and updates all its indexing uses. Returns failure if any of its uses
/// escape (while leaving the IR in a valid state).
LogicalResult normalizeMemRef(memref::AllocOp *op);
template <typename AllocLikeOp>
LogicalResult normalizeMemRef(AllocLikeOp *op);
extern template LogicalResult
normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
extern template LogicalResult
normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);

/// Normalizes `memrefType` so that the affine layout map of the memref is
/// transformed to an identity map with a new shape being computed for the
Expand Down
21 changes: 14 additions & 7 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1639,9 +1639,10 @@ static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
/// %c4 = arith.constant 4 : index
/// %1 = affine.apply #map1(%c4, %0)
/// %2 = affine.apply #map2(%c4, %0)
template <typename AllocLikeOp>
static void createNewDynamicSizes(MemRefType oldMemRefType,
MemRefType newMemRefType, AffineMap map,
memref::AllocOp *allocOp, OpBuilder b,
AllocLikeOp *allocOp, OpBuilder b,
SmallVectorImpl<Value> &newDynamicSizes) {
// Create new input for AffineApplyOp.
SmallVector<Value, 4> inAffineApply;
Expand Down Expand Up @@ -1688,7 +1689,8 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
}

// TODO: Currently works for static memrefs with a single layout map.
LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
template <typename AllocLikeOp>
LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
MemRefType memrefType = allocOp->getType();
OpBuilder b(*allocOp);

Expand All @@ -1704,7 +1706,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {

SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
AffineMap layoutMap = memrefType.getLayout().getAffineMap();
memref::AllocOp newAlloc;
AllocLikeOp newAlloc;
// Check if `layoutMap` is a tiled layout. Only single layout map is
// supported for normalizing dynamic memrefs.
SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
Expand All @@ -1716,11 +1718,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
newDynamicSizes);
// Add the new dynamic sizes in new AllocOp.
newAlloc =
b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
newDynamicSizes, allocOp->getAlignmentAttr());
b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
allocOp->getAlignmentAttr());
} else {
newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
allocOp->getAlignmentAttr());
newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
allocOp->getAlignmentAttr());
}
// Replace all uses of the old memref.
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
Expand All @@ -1745,6 +1747,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
return success();
}

template LogicalResult
mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
template LogicalResult
mlir::affine::normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);

MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
unsigned rank = memrefType.getRank();
if (rank == 0)
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,17 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
.wasInterrupted())
return false;

if (funcOp
.walk([&](memref::AllocaOp allocaOp) -> WalkResult {
Value oldMemRef = allocaOp.getResult();
if (!allocaOp.getType().getLayout().isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
return WalkResult::interrupt();
return WalkResult::advance();
})
.wasInterrupted())
return false;

if (funcOp
.walk([&](func::CallOp callOp) -> WalkResult {
for (unsigned resIndex :
Expand Down Expand Up @@ -347,6 +358,11 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
for (memref::AllocOp allocOp : allocOps)
(void)normalizeMemRef(&allocOp);

SmallVector<memref::AllocaOp, 4> allocaOps;
funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
for (memref::AllocaOp allocaOp : allocaOps)
(void)normalizeMemRef(&allocaOp);

// We use this OpBuilder to create new memref layout later.
OpBuilder b(funcOp);

Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/MemRef/normalize-memrefs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ func.func @permute() {
// CHECK-NEXT: memref.dealloc [[MEM]]
// CHECK-NEXT: return

// CHECK-LABEL: func @alloca
func.func @alloca(%idx : index) {
// CHECK-NEXT: memref.alloca() : memref<65xf32>
%A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
// CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
affine.for %i = 0 to 64 {
%1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
"prevent.dce"(%1) : (f32) -> ()
// CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
}
return
}

// CHECK-LABEL: func @shift
func.func @shift(%idx : index) {
// CHECK-NEXT: memref.alloc() : memref<65xf32>
Expand Down

0 comments on commit 3db9abb

Please sign in to comment.