Skip to content

Commit

Permalink
[LLVMGPU] Fix lowering of functions that don't use all bindings (#19773)
Browse files Browse the repository at this point in the history
Runtime changes in [some previous PR, ask Ben] mean that now, GPU
kernels are passed one pointer for each binding in the pipeline layout,
whether or not they are used. When the previous behavior, which was to
only pass in the needed pointers one after another, was removed, the GPU
code was not updated to match.

This PR updates the conversion from func.func to llvm.func to use one
pointer per binding. It also moves the setting of attributes like
noundef or nonnull into the function conversion, instead of making the
lowerings for hal.interface.binding.subspan and
hal.interface.constant.load redundantly add those attributes.
  • Loading branch information
krzysz00 authored Jan 24, 2025
1 parent 4215100 commit c52eb68
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 90 deletions.
195 changes: 106 additions & 89 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,27 +201,43 @@ class TestLLVMGPULegalizeOpPass final
}
};

/// Convention with the HAL side to pass kernel arguments.
/// The bindings are ordered based on binding set and binding index, then
/// compressed and mapped to a dense set of arguments.
/// This function looks at the symbols and returns the mapping between
/// InterfaceBindingOp and kernel argument index.
/// For instance if the kernel has (set, bindings) A(0, 1), B(1, 5), C(0, 6) it
/// will return the mapping [A, 0], [C, 1], [B, 2]
static llvm::SmallDenseMap<APInt, size_t>
getKernelArgMapping(Operation *funcOp) {
llvm::SetVector<APInt> usedBindingSet;
funcOp->walk([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
usedBindingSet.insert(subspanOp.getBinding());
});
auto sparseBindings = usedBindingSet.takeVector();
std::sort(sparseBindings.begin(), sparseBindings.end(),
[](APInt lhs, APInt rhs) { return lhs.ult(rhs); });
llvm::SmallDenseMap<APInt, size_t> mapBindingArgIndex;
for (auto [index, binding] : llvm::enumerate(sparseBindings)) {
mapBindingArgIndex[binding] = index;
namespace {
/// A package for the results of `analyzeSubspans` to avoid
/// arbitrary tuples. The default values are the results for an unused
/// binding, which is read-only, unused, and in address space 0.
struct BindingProperties {
// Stays true only while every subspan seen for this binding carries the
// ReadOnly descriptor flag (`analyzeSubspans` ANDs each subspan's flag in).
bool readonly = true;
// Cleared by `analyzeSubspans` as soon as any subspan refers to this binding.
bool unused = true;
// Address space recorded from the subspans of this binding: the memref
// address space reported by the type converter, or 0 for non-memref subspans.
unsigned addressSpace = 0;
};
} // namespace
/// Analyze subspan binding ops to recover properties of the binding, such as
/// if it is read-only and the address space it lives in.
///
/// Returns a vector with `numBindings` entries, indexed by binding number.
/// Entries that no subspan in `subspans` refers to keep the default
/// `BindingProperties` values. Emits an op error and returns it when a
/// subspan's address space differs from a non-zero address space already
/// recorded for the same binding.
static FailureOr<SmallVector<BindingProperties>>
analyzeSubspans(llvm::SetVector<IREE::HAL::InterfaceBindingSubspanOp> &subspans,
int64_t numBindings, const LLVMTypeConverter *typeConverter) {
SmallVector<BindingProperties> result(numBindings, BindingProperties{});
for (auto subspan : subspans) {
// NOTE(review): `binding` indexes `result` without a range check against
// `numBindings` — confirm every subspan's binding is within the layout.
int64_t binding = subspan.getBinding().getSExtValue();
result[binding].unused = false;
// A binding stays read-only only if all of its subspans have the ReadOnly
// flag; a missing flags attribute is treated as DescriptorFlags::None.
result[binding].readonly &= IREE::HAL::bitEnumContainsAny(
subspan.getDescriptorFlags().value_or(IREE::HAL::DescriptorFlags::None),
IREE::HAL::DescriptorFlags::ReadOnly);
// Subspans whose result type is not a memref are given address space 0.
unsigned bindingAddrSpace = 0;
auto bindingType = dyn_cast<BaseMemRefType>(subspan.getType());
if (bindingType) {
// NOTE(review): the result of getMemRefAddressSpace is dereferenced
// without checking it — confirm it cannot fail for these types.
bindingAddrSpace = *typeConverter->getMemRefAddressSpace(bindingType);
}
// The consistency check only fires when the previously recorded address
// space is non-zero; a recorded 0 is overwritten below without a diagnostic.
if (result[binding].addressSpace != 0 &&
result[binding].addressSpace != bindingAddrSpace) {
return subspan.emitOpError("address space for this op (" +
Twine(bindingAddrSpace) +
") doesn't match previously found space (" +
Twine(result[binding].addressSpace) + ")");
}
result[binding].addressSpace = bindingAddrSpace;
}
// NOTE(review): `mapBindingArgIndex` is not declared inside this function;
// this line looks like a removed line left over from the diff rendering and
// the `return result;` below appears to be the intended return — confirm.
return mapBindingArgIndex;
return result;
}

class ConvertFunc : public ConvertToLLVMPattern {
Expand All @@ -242,30 +258,46 @@ class ConvertFunc : public ConvertToLLVMPattern {
assert(fnType.getNumInputs() == 0 && fnType.getNumResults() == 0);

TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
auto argMapping = getKernelArgMapping(funcOp);
// There may be dead symbols, we pick i32 pointer as default argument type.
SmallVector<Type, 8> llvmInputTypes(
argMapping.size(), LLVM::LLVMPointerType::get(rewriter.getContext()));
// Note: we assume that the pipeline layout is the same for all bindings
// in this function.
IREE::HAL::PipelineLayoutAttr layout;
llvm::SetVector<IREE::HAL::InterfaceBindingSubspanOp> subspans;
funcOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
LLVM::LLVMPointerType llvmType;
if (auto memrefType = dyn_cast<BaseMemRefType>(subspanOp.getType())) {
unsigned addrSpace =
*getTypeConverter()->getMemRefAddressSpace(memrefType);
llvmType = LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
} else {
llvmType = LLVM::LLVMPointerType::get(rewriter.getContext());
if (!layout) {
layout = subspanOp.getLayout();
}
llvmInputTypes[argMapping[subspanOp.getBinding()]] = llvmType;
subspans.insert(subspanOp);
});
// As a convention with HAL, push constants are appended as kernel arguments
// after all the binding inputs.
uint64_t numConstants = 0;
funcOp.walk([&](IREE::HAL::InterfaceConstantLoadOp constantOp) {
numConstants =
std::max(constantOp.getOrdinal().getZExtValue() + 1, numConstants);

funcOp.walk([&](IREE::HAL::InterfaceConstantLoadOp constOp) {
if (!layout) {
layout = constOp.getLayout();
}
return WalkResult::interrupt();
});
llvmInputTypes.resize(argMapping.size() + numConstants,
rewriter.getI32Type());

int64_t numBindings = 0;
int64_t numConstants = 0;
if (layout) {
numConstants = layout.getConstants();
numBindings = layout.getBindings().size();
}

FailureOr<SmallVector<BindingProperties>> maybeBindingsInfo =
analyzeSubspans(subspans, numBindings, getTypeConverter());
if (failed(maybeBindingsInfo))
return failure();
auto bindingsInfo = std::move(*maybeBindingsInfo);

SmallVector<Type, 8> llvmInputTypes;
llvmInputTypes.reserve(numBindings + numConstants);
for (const auto &info : bindingsInfo) {
llvmInputTypes.push_back(
LLVM::LLVMPointerType::get(rewriter.getContext(), info.addressSpace));
}
// All the push constants are i32 and go at the end of the argument list.
llvmInputTypes.resize(numBindings + numConstants, rewriter.getI32Type());

if (!llvmInputTypes.empty())
signatureConverter.addInputs(llvmInputTypes);

Expand Down Expand Up @@ -296,6 +328,37 @@ class ConvertFunc : public ConvertToLLVMPattern {
return failure();
}

// Set argument attributes.
Attribute unit = rewriter.getUnitAttr();
for (auto [idx, info] : llvm::enumerate(bindingsInfo)) {
// As a convention with HAL all the kernel argument pointers are 16Bytes
// aligned.
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getAlignAttrName(),
rewriter.getI32IntegerAttr(16));
// It is safe to set the noalias attribute as it is guaranteed that the
// ranges within bindings won't alias.
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getNoAliasAttrName(), unit);
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getNonNullAttrName(), unit);
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getNoUndefAttrName(), unit);
if (info.unused) {
// While LLVM can work this out from the lack of use, we might as well
// be explicit here just to be safe.
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getReadnoneAttrName(),
unit);
} else if (info.readonly) {
// Setting the readonly attribute here will generate non-coherent cache
// loads.
newFuncOp.setArgAttr(idx, LLVM::LLVMDialect::getReadonlyAttrName(),
unit);
}
}
for (int64_t i = 0; i < numConstants; ++i) {
// Push constants are never `undef`, annotate that here, just as with
// bindings.
newFuncOp.setArgAttr(numBindings + i,
LLVM::LLVMDialect::getNoUndefAttrName(), unit);
}

rewriter.eraseOp(funcOp);
return success();
}
Expand All @@ -309,25 +372,6 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
IREE::HAL::InterfaceBindingSubspanOp::getOperationName(), context,
converter) {}

/// Reports whether the subspan ops in `llvmFuncOp` that refer to `binding`
/// are all flagged ReadOnly. The walk stops at the first matching subspan
/// that lacks the flag. Returns false when no subspan uses `binding`.
static bool checkAllSubspansReadonly(LLVM::LLVMFuncOp llvmFuncOp,
                                     APInt binding) {
  bool everyMatchReadOnly = false;
  llvmFuncOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp subspan) {
    // Subspans of other bindings do not affect the answer.
    if (subspan.getBinding() != binding)
      return WalkResult::advance();
    // A missing flags attribute counts as "no flags set".
    auto flags = subspan.getDescriptorFlags().value_or(
        IREE::HAL::DescriptorFlags::None);
    everyMatchReadOnly =
        bitEnumContainsAny(flags, IREE::HAL::DescriptorFlags::ReadOnly);
    // One writable subspan settles the result, so stop walking.
    return everyMatchReadOnly ? WalkResult::advance()
                              : WalkResult::interrupt();
  });
  return everyMatchReadOnly;
}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Expand All @@ -337,35 +381,14 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
return failure();
assert(llvmFuncOp.getNumArguments() > 0);

auto argMapping = getKernelArgMapping(llvmFuncOp);
Location loc = op->getLoc();
auto subspanOp = cast<IREE::HAL::InterfaceBindingSubspanOp>(op);
IREE::HAL::InterfaceBindingSubspanOpAdaptor adaptor(
operands, op->getAttrDictionary());
MemRefType memrefType =
llvm::dyn_cast<MemRefType>(subspanOp.getResult().getType());
mlir::BlockArgument llvmBufferArg =
llvmFuncOp.getArgument(argMapping[subspanOp.getBinding()]);
// As a convention with HAL all the kernel argument pointers are 16Bytes
// aligned.
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
LLVM::LLVMDialect::getAlignAttrName(),
rewriter.getI32IntegerAttr(16));
// It is safe to set the noalias attribute as it is guaranteed that the
// ranges within bindings won't alias.
Attribute unit = rewriter.getUnitAttr();
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
LLVM::LLVMDialect::getNoAliasAttrName(), unit);
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
LLVM::LLVMDialect::getNonNullAttrName(), unit);
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
LLVM::LLVMDialect::getNoUndefAttrName(), unit);
if (checkAllSubspansReadonly(llvmFuncOp, subspanOp.getBinding())) {
// Setting the readonly attribute here will generate non-coherent cache
// loads.
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
LLVM::LLVMDialect::getReadonlyAttrName(), unit);
}
llvmFuncOp.getArgument(subspanOp.getBinding().getZExtValue());
// Add the byte offset.
Value llvmBufferBasePtr = llvmBufferArg;

Expand Down Expand Up @@ -468,18 +491,12 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern {
return failure();
assert(llvmFuncOp.getNumArguments() > 0);

auto argMapping = getKernelArgMapping(llvmFuncOp);
auto ireeConstantOp = cast<IREE::HAL::InterfaceConstantLoadOp>(op);
size_t numBindings = ireeConstantOp.getLayout().getBindings().size();
mlir::BlockArgument llvmBufferArg = llvmFuncOp.getArgument(
argMapping.size() + ireeConstantOp.getOrdinal().getZExtValue());
numBindings + ireeConstantOp.getOrdinal().getZExtValue());
assert(llvmBufferArg.getType().isInteger(32));

// Push constants are never `undef`, annotate that here, just as with
// bindings.
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
LLVM::LLVMDialect::getNoUndefAttrName(),
rewriter.getUnitAttr());

Type dstType = getTypeConverter()->convertType(ireeConstantOp.getType());
// llvm.zext requires that the result type has a larger bitwidth.
if (dstType == llvmBufferArg.getType()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ builtin.module {
// INDEX32-LABEL: llvm.func @abs_ex_dispatch_0
// CHECK-SAME: (%{{[a-zA-Z0-9]*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef, llvm.readonly},
// CHECK-SAME: %{{[a-zA-Z0-9]*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
// CHECK-SAME: %{{[a-zA-Z0-9]*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef})
// CHECK-SAME: %{{[a-zA-Z0-9]*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
// CHECK-SAME: %{{[a-zA-Z0-9]*}}: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef, llvm.readnone})
// CHECK: rocdl.workgroup.dim.x
// CHECK: llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// INDEX32: llvm.getelementptr %{{.*}} : (!llvm.ptr, i32) -> !llvm.ptr, f32
Expand Down Expand Up @@ -230,3 +231,31 @@ module {
}
// CHECK-LABEL: llvm.func @emulation_lowering(
// CHECK-NOT: builtin.unrealized_conversion_cast

// -----
// Test that an unused binding still appears in the kernargs
// Layout with one push constant and three storage-buffer bindings. The
// function below only references bindings 0 and 2, leaving binding 1 unused.
#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
builtin.module {
func.func @missing_ptr_dispatch_copy_idx_0() {
%c0 = arith.constant 0 : index
// The single push constant is loaded and used as the offset of binding 0.
%0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
%1 = arith.index_castui %0 : i32 to index
// Binding 0 is flagged ReadOnly and is only loaded from.
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) offset(%1) flags(ReadOnly) : memref<16xf32, strided<[1], offset : ?>, #gpu.address_space<global>>
// Binding 2 has no flags and is the store destination; no subspan refers to
// binding 1.
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<16xf32, #gpu.address_space<global>>
%4 = memref.load %2[%c0] : memref<16xf32, strided<[1], offset : ?>, #gpu.address_space<global>>
memref.store %4, %3[%c0] : memref<16xf32, #gpu.address_space<global>>
return
}
}
// CHECK-LABEL: llvm.func @missing_ptr_dispatch_copy_idx_0
// CHECK-SAME: (%[[arg0:.+]]: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef, llvm.readonly},
// CHECK-SAME: %[[arg1:.+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef, llvm.readnone},
// CHECK-SAME: %[[arg2:.+]]: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
// CHECK-SAME: %[[arg3:.+]]: i32 {llvm.noundef})
// CHECK: llvm.zext %[[arg3]] : i32 to i64
// CHECK: llvm.insertvalue %[[arg0]]
// CHECK: llvm.insertvalue %[[arg2]]

0 comments on commit c52eb68

Please sign in to comment.