Skip to content

Commit

Permalink
add flag: --iree-llvmcpu-link-ukernel-bitcode, default true (iree…
Browse files Browse the repository at this point in the history
…-org#15354)

@lundong reports that linking to ukernels code as a system plugin (.so)
is broken since we added ukernels bitcode, relying on overridable
generic weak symbols being linked last, so that the non-weak optimized
implementations actually override them. There's no easy fix and no great
alternative to weak symbols for what we're doing here, but this flag
should unblock this use case...
  • Loading branch information
bjacob authored Nov 10, 2023
1 parent 2f9a1e1 commit fc44185
Showing 1 changed file with 119 additions and 110 deletions.
229 changes: 119 additions & 110 deletions compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ static llvm::cl::opt<bool> clEnableCPUMicrokernels(
"Enables microkernel lowering for llvmcpu backend (experimental)"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clLinkCPUUKernelBitcode(
"iree-llvmcpu-link-ukernel-bitcode",
llvm::cl::desc("Link ukernel bitcode libraries into generated executables"),
llvm::cl::init(true));

static llvm::cl::opt<unsigned> clNativeVectorWidthInBytes(
"iree-llvmcpu-native-vector-width-in-bytes",
llvm::cl::desc("sets the native vector register width of the hardware. It "
Expand Down Expand Up @@ -493,130 +498,134 @@ class LLVMCPUTargetBackend final : public TargetBackend {
}
}

// Tracks ukernel functions, in order to set their linkage to internal
// after ukernel bitcode modules are linked but before runLLVMIRPasses, so
// that unused ukernel code paths get DCE'd. Notes:
// 1. We can't rely on fixupVisibility to do this, because fixupVisibility
// is called after runLLVMIRPasses, which is what performs DCE. The
// reason why fixupVisibility can't be moved before runLLVMIRPasses is
// that causes all math functions to be DCE'd, as references to them get
// introduced only later down. The basic difference here between ukernel
// functions and math functions is that any references to ukernel
// functions already exist at this point.
// 2. We can't just set internal linkage right away upon loading ukernel
// bitcode modules, because some ukernel symbols have to override weak
// symbols, and that's disabled when linkage is set to internal.
std::unordered_set<std::string> ukernelFunctions;

// Link in ukernel bitcode.
if (hasMicrokernel(variantOp)) {
auto setAlwaysInline = [&](llvm::Module &module) {
for (auto &func : module.getFunctionList()) {
func.addFnAttr(llvm::Attribute::AlwaysInline);
}
};
auto addUkernelFunctions = [&](const llvm::Module &module) {
for (auto &func : module.getFunctionList()) {
if (func.isDeclaration()) {
continue;
if (clLinkCPUUKernelBitcode) {
// Tracks ukernel functions, in order to set their linkage to internal
// after ukernel bitcode modules are linked but before runLLVMIRPasses, so
// that unused ukernel code paths get DCE'd. Notes:
// 1. We can't rely on fixupVisibility to do this, because fixupVisibility
// is called after runLLVMIRPasses, which is what performs DCE. The
// reason why fixupVisibility can't be moved before runLLVMIRPasses is
// that causes all math functions to be DCE'd, as references to them
// get introduced only later down. The basic difference here between
// ukernel functions and math functions is that any references to
// ukernel functions already exist at this point.
// 2. We can't just set internal linkage right away upon loading ukernel
// bitcode modules, because some ukernel symbols have to override weak
// symbols, and that's disabled when linkage is set to internal.
std::unordered_set<std::string> ukernelFunctions;

// Link in ukernel bitcode.
if (hasMicrokernel(variantOp)) {
auto setAlwaysInline = [&](llvm::Module &module) {
for (auto &func : module.getFunctionList()) {
func.addFnAttr(llvm::Attribute::AlwaysInline);
}
ukernelFunctions.insert(func.getName().str());
};
auto addUkernelFunctions = [&](const llvm::Module &module) {
for (auto &func : module.getFunctionList()) {
if (func.isDeclaration()) {
continue;
}
ukernelFunctions.insert(func.getName().str());
}
};

llvm::Expected<std::unique_ptr<llvm::Module>> archBitcode =
loadUKernelArchBitcode(targetMachine.get(), context);
if (!archBitcode) {
return mlir::emitError(variantOp.getLoc())
<< "failed to load architecture-specific ukernel bitcode: "
<< llvm::toString(archBitcode.takeError());
}
};

llvm::Expected<std::unique_ptr<llvm::Module>> archBitcode =
loadUKernelArchBitcode(targetMachine.get(), context);
if (!archBitcode) {
return mlir::emitError(variantOp.getLoc())
<< "failed to load architecture-specific ukernel bitcode: "
<< llvm::toString(archBitcode.takeError());
}
llvm::Expected<std::unique_ptr<llvm::Module>> archEntryPointsBitcode =
loadUKernelArchEntryPointsBitcode(targetMachine.get(), context);
if (!archEntryPointsBitcode) {
return mlir::emitError(variantOp.getLoc())
<< "failed to load architecture-specific ukernel entry points "
"bitcode: "
<< llvm::toString(archEntryPointsBitcode.takeError());
}

llvm::Expected<std::unique_ptr<llvm::Module>> archEntryPointsBitcode =
loadUKernelArchEntryPointsBitcode(targetMachine.get(), context);
if (!archEntryPointsBitcode) {
return mlir::emitError(variantOp.getLoc())
<< "failed to load architecture-specific ukernel entry points "
"bitcode: "
<< llvm::toString(archEntryPointsBitcode.takeError());
}
// archBitcode and archEntryPointsBitcode are optional, may be null if
// there is none for the target architecture. However, they should
// simultaneously be null or non-null.
if ((archBitcode.get() == nullptr) !=
(archEntryPointsBitcode.get() == nullptr)) {
return mlir::emitError(variantOp.getLoc())
<< "there should be architecture-specific ukernel bit code "
"if, "
"and only if there is architecture-specific ukernels entry "
"points bitcode.";
}

// archBitcode and archEntryPointsBitcode are optional, may be null if
// there is none for the target architecture. However, they should
// simultaneously be null or non-null.
if ((archBitcode.get() == nullptr) !=
(archEntryPointsBitcode.get() == nullptr)) {
return mlir::emitError(variantOp.getLoc())
<< "there should be architecture-specific ukernel bit code if, "
"and only if there is architecture-specific ukernels entry "
"points bitcode.";
}
if (archBitcode.get()) {
addUkernelFunctions(*archBitcode.get());
addUkernelFunctions(*archEntryPointsBitcode.get());

// archEntryPointsBitcode contains overrides for weak symbols that
// will come in the baseBitcode below. So we link it before
// baseBitcode, with OverrideFromSrc.
StringRef archEntryPointsBitcodeName =
archEntryPointsBitcode.get()->getName();
if (failed(linkBitcodeModule(
variantOp.getLoc(), moduleLinker, 0, *targetMachine,
archEntryPointsBitcodeName, std::move(archEntryPointsBitcode),
setAlwaysInline))) {
return mlir::emitError(variantOp.getLoc())
<< "failed linking in architecture-specific ukernel entry "
"points bitcode "
"for target triple '"
<< targetTriple.str() << "'";
}

if (archBitcode.get()) {
addUkernelFunctions(*archBitcode.get());
addUkernelFunctions(*archEntryPointsBitcode.get());

// archEntryPointsBitcode contains overrides for weak symbols that will
// come in the baseBitcode below. So we link it before baseBitcode, with
// OverrideFromSrc.
StringRef archEntryPointsBitcodeName =
archEntryPointsBitcode.get()->getName();
if (failed(linkBitcodeModule(variantOp.getLoc(), moduleLinker, 0,
*targetMachine, archEntryPointsBitcodeName,
std::move(archEntryPointsBitcode),
setAlwaysInline))) {
return mlir::emitError(variantOp.getLoc())
<< "failed linking in architecture-specific ukernel entry "
"points bitcode "
"for target triple '"
<< targetTriple.str() << "'";
// archEntryPointsBitcode references symbols defined in archBitcode,
// so we link that now. We can apply LinkOnlyNeeded, since the only
// purpose of archBitcode is to satisfy references made in
// archEntryPointsBitcode.
StringRef archBitcodeName = archBitcode.get()->getName();
if (failed(linkBitcodeModule(variantOp.getLoc(), moduleLinker,
llvm::Linker::LinkOnlyNeeded,
*targetMachine, archBitcodeName,
std::move(archBitcode), {}))) {
return mlir::emitError(variantOp.getLoc())
<< "failed linking in architecture-specific ukernel bitcode "
"for target triple '"
<< targetTriple.str() << "'";
}
}

// archEntryPointsBitcode references symbols defined in archBitcode, so
// we link that now. We can apply LinkOnlyNeeded, since the only purpose
// of archBitcode is to satisfy references made in
// archEntryPointsBitcode.
StringRef archBitcodeName = archBitcode.get()->getName();
// The baseBitcode module contains weak symbols for fallbacks,
// potentially overridden by symbols defined in archEntryPointsBitcode
// above. So this must be linked after archEntryPointsBitcode. The
// baseBitcode module contains the actual ukernel entry points as seen
// from the MLIR module, and its purpose is to satisfy these references,
// so we can apply LinkOnlyNeeded here.
llvm::Expected<std::unique_ptr<llvm::Module>> baseBitcode =
loadUKernelBaseBitcode(targetMachine.get(), context);
if (baseBitcode) {
addUkernelFunctions(*baseBitcode.get());
}
// Sequence that access before we std::move(baseBitcode)!
StringRef baseBitcodeName =
baseBitcode ? baseBitcode.get()->getName() : "";
if (failed(linkBitcodeModule(
variantOp.getLoc(), moduleLinker, llvm::Linker::LinkOnlyNeeded,
*targetMachine, archBitcodeName, std::move(archBitcode), {}))) {
*targetMachine, baseBitcodeName, std::move(baseBitcode),
setAlwaysInline))) {
return mlir::emitError(variantOp.getLoc())
<< "failed linking in architecture-specific ukernel bitcode "
"for target triple '"
<< targetTriple.str() << "'";
<< "failed linking in base ukernel bitcode";
}
}

// The baseBitcode module contains weak symbols for fallbacks, potentially
// overridden by symbols defined in archEntryPointsBitcode above. So this
// must be linked after archEntryPointsBitcode.
// The baseBitcode module contains the actual ukernel entry points as seen
// from the MLIR module, and its purpose is to satisfy these references,
// so we can apply LinkOnlyNeeded here.
llvm::Expected<std::unique_ptr<llvm::Module>> baseBitcode =
loadUKernelBaseBitcode(targetMachine.get(), context);
if (baseBitcode) {
addUkernelFunctions(*baseBitcode.get());
}
// Sequence that access before we std::move(baseBitcode)!
StringRef baseBitcodeName =
baseBitcode ? baseBitcode.get()->getName() : "";
if (failed(linkBitcodeModule(variantOp.getLoc(), moduleLinker,
llvm::Linker::LinkOnlyNeeded, *targetMachine,
baseBitcodeName, std::move(baseBitcode),
setAlwaysInline))) {
return mlir::emitError(variantOp.getLoc())
<< "failed linking in base ukernel bitcode";
}
}

// Set internal linkage on all ukernel functions. No new references to
// ukernels will be created past this point, so any unreferenced ukernel
// symbol is safe to DCE, which will happen below in runLLVMIRPasses, so we
// need to set internal linkage before that.
for (auto &func : llvmModule->getFunctionList()) {
if (ukernelFunctions.count(func.getName().str())) {
func.setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
// Set internal linkage on all ukernel functions. No new references to
// ukernels will be created past this point, so any unreferenced ukernel
// symbol is safe to DCE, which will happen below in runLLVMIRPasses, so
// we need to set internal linkage before that.
for (auto &func : llvmModule->getFunctionList()) {
if (ukernelFunctions.count(func.getName().str())) {
func.setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
}
}
}

Expand Down

0 comments on commit fc44185

Please sign in to comment.