Skip to content

Commit

Permalink
[MemRef] Migrate away from PointerUnion::{is,get} (NFC) (llvm#120202)
Browse files Browse the repository at this point in the history
Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.
  • Loading branch information
kazutakahirata authored Dec 17, 2024
1 parent 345a352 commit 30916b6
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 13 deletions.
12 changes: 5 additions & 7 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ static void constifyIndexValues(
values[it.index()] = builder.getIndexAttr(constValue);
}
for (OpFoldResult &ofr : values) {
if (ofr.is<Attribute>()) {
if (auto attr = dyn_cast<Attribute>(ofr)) {
// FIXME: We shouldn't need to do that, but right now, the static indices
// are created with the wrong type: `i64` instead of `index`.
// As a result, if we were to keep the attribute as is, we may fail to see
Expand All @@ -139,12 +139,11 @@ static void constifyIndexValues(
// The workaround here is to stick to the IndexAttr type for all the
// values, hence we recreate the attribute even when it is already static
// to make sure the type is consistent.
ofr = builder.getIndexAttr(
llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
continue;
}
std::optional<int64_t> maybeConstant =
getConstantIntValue(ofr.get<Value>());
getConstantIntValue(cast<Value>(ofr));
if (maybeConstant)
ofr = builder.getIndexAttr(*maybeConstant);
}
Expand Down Expand Up @@ -1406,12 +1405,11 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
// infinite loops in the driver.
if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
continue;
assert(maybeConstant.template is<Attribute>() &&
assert(isa<Attribute>(maybeConstant) &&
"The constified value should be either unchanged (i.e., == result) "
"or a constant");
Value constantVal = rewriter.create<arith::ConstantIndexOp>(
loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
.getInt());
loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
// modifyOpInPlace: lambda cannot capture structured bindings in C++17
// yet.
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
sourceOp.getMixedStrides(), op.getMixedSizes())) {
// We only support static sizes.
if (opSize.is<Value>()) {
if (isa<Value>(opSize)) {
return failure();
}
sizes.push_back(opSize);
Expand All @@ -109,7 +109,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
} else {
expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
affineApplyOperands.push_back(sourceOffset.get<Value>());
affineApplyOperands.push_back(cast<Value>(sourceOffset));
}

// Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
Expand All @@ -121,7 +121,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
expr =
expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
cast<IntegerAttr>(sourceStrideAttr).getInt();
affineApplyOperands.push_back(opOffset.get<Value>());
affineApplyOperands.push_back(cast<Value>(opOffset));
}

AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
AffineExpr s1 = builder.getAffineSymbolExpr(1);
for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
int64_t baseExpandedStride =
cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
.getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(),
Expand All @@ -396,7 +396,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
AffineExpr s0 = builder.getAffineSymbolExpr(0);
for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
int64_t baseExpandedStride =
cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
.getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using namespace mlir::memref;
static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
OpFoldResult ofr,
ValueRange independencies) {
if (ofr.is<Attribute>())
if (isa<Attribute>(ofr))
return ofr;
AffineMap boundMap;
ValueDimList mapOperands;
Expand Down

0 comments on commit 30916b6

Please sign in to comment.