Skip to content

Commit

Permalink
[mlir][LLVM] Switch undef for poison for uninitialized values (ll…
Browse files Browse the repository at this point in the history
…vm#125629)

LLVM itself is generally moving away from using `undef` and towards
using `poison`, to the point of having a lint that caches new uses of
`undef` in tests.

In order to not trip the lint on new patterns and to conform to the
evolution of LLVM
- Rename valious ::undef() methods on StructBuilder subclasses to
::poison()
- Audit the uses of UndefOp in the MLIR libraries and replace almost all
of them with PoisonOp

The remaining uses of `undef` are initializing `uninitialized` memrefs,
explicit conversions to undef from SPIR-V, and a few cases in
AMDGPUToROCDL where usage like

    %v = insertelement <M x iN> undef, iN %v, i32 0
    %arg = bitcast <M x iN> %v to i(M * N)

is used to handle "i32" arguments that are are really packed vectors of
smaller types that won't always be fully initialized.
  • Loading branch information
krzysz00 authored Feb 6, 2025
1 parent e41ffd3 commit f4e3b87
Show file tree
Hide file tree
Showing 40 changed files with 457 additions and 455 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class ComplexStructBuilder : public StructBuilder {
/// Construct a helper for the given complex number value.
using StructBuilder::StructBuilder;
/// Build IR creating an `undef` value of the complex number type.
static ComplexStructBuilder undef(OpBuilder &builder, Location loc,
Type type);
static ComplexStructBuilder poison(OpBuilder &builder, Location loc,
Type type);

// Build IR extracting the real value from the complex number struct.
Value real(OpBuilder &builder, Location loc);
Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ class MemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
explicit MemRefDescriptor(Value descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a `poison` value of the descriptor type.
static MemRefDescriptor poison(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
Expand Down Expand Up @@ -160,8 +160,8 @@ class UnrankedMemRefDescriptor : public StructBuilder {
/// Construct a helper for the given descriptor value.
explicit UnrankedMemRefDescriptor(Value descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc,
Type descriptorType);

/// Builds IR extracting the rank from the descriptor
Value rank(OpBuilder &builder, Location loc) const;
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class StructBuilder {
public:
/// Construct a helper for the given value.
explicit StructBuilder(Value v);
/// Builds IR creating an `undef` value of the descriptor type.
static StructBuilder undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a `poison` value of the descriptor type.
static StructBuilder poison(OpBuilder &builder, Location loc,
Type descriptorType);

/*implicit*/ operator Value() { return value; }

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(predicateType, true));
// Create padding vector (never used due to all-true predicate).
auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType);
// Get a pointer to the current slice.
auto slicePtr =
getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
Expand Down
17 changes: 9 additions & 8 deletions mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ using namespace mlir::arith;
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;

ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
Location loc, Type type) {
Value val = builder.create<LLVM::UndefOp>(loc, type);
ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder,
Location loc, Type type) {
Value val = builder.create<LLVM::PoisonOp>(loc, type);
return ComplexStructBuilder(val);
}

Expand Down Expand Up @@ -109,7 +109,8 @@ struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
// Pack real and imaginary part in a complex number struct.
auto loc = complexOp.getLoc();
auto structType = typeConverter->convertType(complexOp.getType());
auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
auto complexStruct =
ComplexStructBuilder::poison(rewriter, loc, structType);
complexStruct.setReal(rewriter, loc, adaptor.getReal());
complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());

Expand Down Expand Up @@ -183,7 +184,7 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);

// Emit IR to add complex numbers.
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
Expand Down Expand Up @@ -214,7 +215,7 @@ struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);

// Emit IR to add complex numbers.
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
Expand Down Expand Up @@ -262,7 +263,7 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);

// Emit IR to add complex numbers.
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
Expand Down Expand Up @@ -302,7 +303,7 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
auto result = ComplexStructBuilder::poison(rewriter, loc, structType);

// Emit IR to substract complex numbers.
arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ struct UnrealizedConversionCastOpLowering
// `ReturnOp` interacts with the function signature and must have as many
// operands as the function has return values. Because in LLVM IR, functions
// can only return 0 or 1 value, we pack multiple values into a structure type.
// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
// Emit `PoisonOp` followed by `InsertValueOp`s to create such structure if
// necessary before returning it
struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
Expand Down Expand Up @@ -714,7 +714,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
return rewriter.notifyMatchFailure(op, "could not convert result types");
}

Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
return rewriter.notifyMatchFailure(op, "expected vector result");

Location loc = op->getLoc();
Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
Type indexType = converter.convertType(rewriter.getIndexType());
StringAttr name = op->getName().getIdentifier();
Type elementType = vectorType.getElementType();
Expand Down Expand Up @@ -771,7 +771,7 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "could not convert result types");
}

Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ struct WmmaConstantOpToNVVMLowering
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
// If the element type is a vector create a vector from the operand.
if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
Value idx = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), vecEl);
Expand All @@ -288,7 +288,7 @@ struct WmmaConstantOpToNVVMLowering
}
cst = vecCst;
}
Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type);
for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
matrixStruct =
rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
Expand Down Expand Up @@ -355,7 +355,7 @@ struct WmmaElementwiseOpToNVVMLowering
size_t numOperands = adaptor.getOperands().size();
LLVM::LLVMStructType destType = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType);
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
Expand Down
20 changes: 10 additions & 10 deletions mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor)
}

/// Builds IR creating an `undef` value of the descriptor type.
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
Type descriptorType) {
MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc,
Type descriptorType) {

Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
return MemRefDescriptor(descriptor);
}

Expand Down Expand Up @@ -60,7 +60,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
auto convertedType = typeConverter.convertType(type);
assert(convertedType && "unexpected failure in memref type conversion");

auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
auto descr = MemRefDescriptor::poison(builder, loc, convertedType);
descr.setAllocatedPtr(builder, loc, memory);
descr.setAlignedPtr(builder, loc, alignedMemory);
descr.setConstantOffset(builder, loc, offset);
Expand Down Expand Up @@ -224,7 +224,7 @@ Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
const LLVMTypeConverter &converter,
MemRefType type, ValueRange values) {
Type llvmType = converter.convertType(type);
auto d = MemRefDescriptor::undef(builder, loc, llvmType);
auto d = MemRefDescriptor::poison(builder, loc, llvmType);

d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
Expand Down Expand Up @@ -300,10 +300,10 @@ UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
: StructBuilder(descriptor) {}

/// Builds IR creating an `undef` value of the descriptor type.
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
Location loc,
Type descriptorType) {
Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder,
Location loc,
Type descriptorType) {
Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
return UnrankedMemRefDescriptor(descriptor);
}
Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
Expand Down Expand Up @@ -331,7 +331,7 @@ Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
UnrankedMemRefType type,
ValueRange values) {
Type llvmType = converter.convertType(type);
auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
auto d = UnrankedMemRefDescriptor::poison(builder, loc, llvmType);

d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
ArrayRef<Value> sizes, ArrayRef<Value> strides,
ConversionPatternRewriter &rewriter) const {
auto structType = typeConverter->convertType(memRefType);
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
auto memRefDescriptor = MemRefDescriptor::poison(rewriter, loc, structType);

// Field 1: Allocated pointer, used for malloc/free.
memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
Expand Down Expand Up @@ -319,7 +319,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
if (!descriptorType)
return failure();
auto updatedDesc =
UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
Value rank = desc.rank(builder, loc);
updatedDesc.setRank(builder, loc, rank);
updatedDesc.setMemRefDescPtr(builder, loc, memory);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
auto loc = op->getLoc();
Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
Value desc = rewriter.create<LLVM::PoisonOp>(loc, resultNDVectoryTy);
nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
Expand Down
18 changes: 9 additions & 9 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,10 +714,10 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
// rank = ConstantOp srcRank
auto rankVal = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(rank));
// undef = UndefOp
// poison = PoisonOp
UnrankedMemRefDescriptor memRefDesc =
UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
// d1 = InsertValueOp undef, rank, 0
UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
// d1 = InsertValueOp poison, rank, 0
memRefDesc.setRank(rewriter, loc, rankVal);
// d2 = InsertValueOp d1, ptr, 1
memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
Expand Down Expand Up @@ -928,7 +928,7 @@ struct MemorySpaceCastOpLowering
Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

// Create and allocate storage for new memref descriptor.
auto result = UnrankedMemRefDescriptor::undef(
auto result = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(resultTypeU));
result.setRank(rewriter, loc, rank);
SmallVector<Value, 1> sizes;
Expand Down Expand Up @@ -1058,7 +1058,7 @@ struct MemRefReinterpretCastOpLowering

// Create descriptor.
Location loc = castOp.getLoc();
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

// Set allocated and aligned pointers.
Value allocatedPtr, alignedPtr;
Expand Down Expand Up @@ -1128,7 +1128,7 @@ struct MemRefReshapeOpLowering
// Create descriptor.
Location loc = reshapeOp.getLoc();
auto desc =
MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

// Set allocated and aligned pointers.
Value allocatedPtr, alignedPtr;
Expand Down Expand Up @@ -1210,7 +1210,7 @@ struct MemRefReshapeOpLowering

// Create the unranked memref descriptor that holds the ranked one. The
// inner descriptor is allocated on stack.
auto targetDesc = UnrankedMemRefDescriptor::undef(
auto targetDesc = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(targetType));
targetDesc.setRank(rewriter, loc, resultRank);
SmallVector<Value, 4> sizes;
Expand Down Expand Up @@ -1366,7 +1366,7 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
if (transposeOp.getPermutation().isIdentity())
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

auto targetMemRef = MemRefDescriptor::undef(
auto targetMemRef = MemRefDescriptor::poison(
rewriter, loc,
typeConverter->convertType(transposeOp.getIn().getType()));

Expand Down Expand Up @@ -1469,7 +1469,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {

// Create the descriptor.
MemRefDescriptor sourceMemRef(adaptor.getSource());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);

// Field 1: Copy the allocated pointer, used for malloc/free.
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
Expand Down
Loading

0 comments on commit f4e3b87

Please sign in to comment.