-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[loop count assumptions] convert loop iteration metadata to assumptions
- Loading branch information
1 parent
7a8f59c
commit 014f9f5
Showing
7 changed files
with
1,548 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
llvm/include/llvm/Transforms/Utils/LoopIterCountAssumptions.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
//===-- LoopIterCountAssumptions.h - Add loop assumptions -------*- C++ -*-===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
// (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This pass converts Loop Iteration Count Metadata to Assumptions which can be | ||
// picked up by Loop Rotate to remove Loop Guards. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef LLVM_TRANSFORMS_UTILS_LOOPITERCOUNTASSUMPTIONS_H | ||
#define LLVM_TRANSFORMS_UTILS_LOOPITERCOUNTASSUMPTIONS_H | ||
#include "llvm/IR/PassManager.h" | ||
#include "llvm/Passes/PassBuilder.h" | ||
|
||
namespace llvm { | ||
|
||
class Loop; | ||
/// Converts Loop Iteration Count Metadata to Assumptions. | ||
class LoopIterCountAssumptions | ||
: public PassInfoMixin<LoopIterCountAssumptions> { | ||
|
||
public: | ||
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, | ||
LoopStandardAnalysisResults &AR, LPMUpdater &U); | ||
}; | ||
|
||
} // namespace llvm | ||
|
||
#endif // LLVM_TRANSFORMS_UTILS_LOOPITERCOUNTASSUMPTIONS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
//===-- LoopIterCountAssumptions.cpp - add Loop assumptions -----*- C++ -*-===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
// (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This pass converts Loop Iteration Count Metadata to Assumptions which can be | ||
// picked up by Loop Rotate to remove Loop Guards. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "llvm/Transforms/Utils/LoopIterCountAssumptions.h" | ||
#include "llvm/Analysis/AliasAnalysis.h" | ||
#include "llvm/Analysis/AssumptionCache.h" | ||
#include "llvm/Analysis/MemorySSA.h" | ||
#include "llvm/Analysis/MemorySSAUpdater.h" | ||
#include "llvm/IR/IRBuilder.h" | ||
#include "llvm/IR/Value.h" | ||
#include "llvm/Passes/PassBuilder.h" | ||
#include "llvm/Support/Casting.h" | ||
#include "llvm/Transforms/Scalar/LICM.h" | ||
#include "llvm/Transforms/Utils/LoopUtils.h" | ||
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" | ||
|
||
#define DEBUG_TYPE "loop-iter-count-assumptions" | ||
|
||
using namespace llvm; | ||
|
||
namespace { | ||
|
||
std::string getFunctionAndBlockNames(BasicBlock &BB) { | ||
return BB.getParent()->getName().str() + " " + BB.getName().str(); | ||
} | ||
|
||
/// Return the Branch Compare Instruction of CurrentLoop if the Loop is well | ||
/// formed and this pass can process the Predicate | ||
ICmpInst *getLoopCmpInst(Loop &CurrentLoop) { | ||
|
||
if (CurrentLoop.isRotatedForm()) { | ||
LLVM_DEBUG(dbgs() << "Loop already in rotated form. Will not add Loop " | ||
"Iteration Count assumptions.\n"); | ||
return nullptr; | ||
} | ||
|
||
/// Check that the loop has a single Exiting Block. If the CurrentLoop | ||
/// contains multiple Exiting Block, ExitBB will be a nullptr | ||
auto *ExitBB = CurrentLoop.getExitingBlock(); | ||
if (!ExitBB) | ||
return nullptr; | ||
|
||
BranchInst *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()); | ||
if (!BI) | ||
return nullptr; | ||
|
||
ICmpInst *LoopCmpInstr = dyn_cast<ICmpInst>(BI->getCondition()); | ||
if (!LoopCmpInstr) | ||
return nullptr; | ||
|
||
LLVM_DEBUG(dbgs() << "Condition Found: " << *LoopCmpInstr << "\n"); | ||
|
||
// Do not process equal and not-equal Compare Instructions, since they do not | ||
// work well with the currently used assumption-construction-method of reusing | ||
// the ICmpInst Predicate | ||
if (LoopCmpInstr->getPredicate() == CmpInst::ICMP_EQ || | ||
LoopCmpInstr->getPredicate() == CmpInst::ICMP_NE) { | ||
LLVM_DEBUG(dbgs() << "LoopIterCountAssumptions-Warning: EQ and NE " | ||
"conditions are currently not supported!"); | ||
return nullptr; | ||
} | ||
return LoopCmpInstr; | ||
} | ||
|
||
/// Return the AddRecExpr evaluated at Iteration 0 if an AddRecExpr can be | ||
/// extracted, otherwise return Op | ||
Value *expandValueAtZeroIteration(Value *Op, ScalarEvolution &SE, | ||
SCEVExpander &Expander, | ||
Instruction *InsertionPoint) { | ||
const SCEVAddRecExpr *AddRec = | ||
dyn_cast_or_null<SCEVAddRecExpr>(SE.getSCEV(Op)); | ||
if (!AddRec) { | ||
LLVM_DEBUG(dbgs() << "Could not extract AddRecExpr, will reuse " << *Op | ||
<< "\n"); | ||
return Op; | ||
} | ||
|
||
return Expander.expandCodeFor( | ||
AddRec->evaluateAtIteration(SE.getConstant(APInt(32, 0)), SE), | ||
Op->getType(), InsertionPoint); | ||
} | ||
|
||
Value *recursivlyCloneInBB(Value *Op, BasicBlock &BB, ValueToValueMapTy &VMap, | ||
AAResults *AA, DominatorTree *DT, Loop *CurLoop, | ||
MemorySSAUpdater &MSSAU, | ||
bool TargetExecutesOncePerLoop, | ||
SinkAndHoistLICMFlags &Flags) { | ||
if (isa<PHINode>(Op)) { | ||
LLVM_DEBUG(dbgs() << "Found end: Phi node: do not clone " << *Op << "\n"); | ||
// return Value that is in the BB, i.e. in the preheader | ||
PHINode *PN = dyn_cast<PHINode>(Op); | ||
return PN->getIncomingValueForBlock(&BB); | ||
} | ||
|
||
Instruction *I = dyn_cast<Instruction>(Op); | ||
if (!I) { | ||
LLVM_DEBUG(dbgs() << "Found end: " << *Op << "\n"); | ||
return Op; | ||
} | ||
if (!canSinkOrHoistInst(*I, AA, DT, CurLoop, MSSAU, TargetExecutesOncePerLoop, | ||
Flags)) { | ||
LLVM_DEBUG(dbgs() << "Could not hoist " << *I << "\n"); | ||
return nullptr; | ||
} | ||
|
||
auto NewI = I->clone(); | ||
NewI->insertBefore(&BB.front()); | ||
VMap[I] = NewI; | ||
LLVM_DEBUG(dbgs() << "Cloning " << *NewI << " into " << BB.getName() << "\n"); | ||
for (Value *Use : I->operands()) { | ||
LLVM_DEBUG(dbgs() << "Try cloning " << *Use << " of Instruction " << *I | ||
<< "\n"); | ||
if (!recursivlyCloneInBB(Use, BB, VMap, AA, DT, CurLoop, MSSAU, | ||
TargetExecutesOncePerLoop, Flags)) | ||
return nullptr; | ||
} | ||
llvm::RemapInstruction(NewI, VMap, | ||
RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); | ||
return NewI; | ||
} | ||
|
||
// Create the assumption into the Loop Header, that at iteration 0 the condition | ||
// is true | ||
void insertMinIterAssumption(ICmpInst &LoopCmpInstr, BasicBlock &LoopPreHeader, | ||
ScalarEvolution &SE, AssumptionCache &AC, | ||
AAResults *AA, DominatorTree *DT, | ||
Loop &CurrentLoop, MemorySSAUpdater &MSSAU, | ||
SinkAndHoistLICMFlags &LICMFlags) { | ||
Instruction *InsertionPoint = LoopPreHeader.getTerminator(); | ||
|
||
// LoopRotate uses SimplifyQuery to determine, if a Branch is conditional or | ||
// not. SimplifyQuery can only take an Assumption into account, if it is | ||
// before the to-be-evaluated Compare Instruction. Here they are inserted into | ||
// the Preheader. | ||
IRBuilder<> Builder(dyn_cast<Instruction>(InsertionPoint)); | ||
|
||
SCEVExpander Expander(SE, LoopPreHeader.getModule()->getDataLayout(), | ||
"expanded"); | ||
|
||
// If AddRecExpr of LHS is available, evaluate at Iteration 0, otherwise | ||
// return LHS | ||
Value *LHS = expandValueAtZeroIteration(LoopCmpInstr.getOperand(0), SE, | ||
Expander, InsertionPoint); | ||
ValueToValueMapTy VMapLHS; | ||
LHS = recursivlyCloneInBB(LHS, LoopPreHeader, VMapLHS, AA, DT, &CurrentLoop, | ||
MSSAU, false, LICMFlags); | ||
if (!LHS) | ||
return; | ||
|
||
LLVM_DEBUG(dbgs() << "LHS = " << *LHS << "\n"); | ||
|
||
// If AddRecExpr of RHS is available, evaluate at Iteration 0, otherwise | ||
// return RHS | ||
Value *RHS = expandValueAtZeroIteration(LoopCmpInstr.getOperand(1), SE, | ||
Expander, InsertionPoint); | ||
ValueToValueMapTy VMapRHS; | ||
RHS = recursivlyCloneInBB(RHS, LoopPreHeader, VMapRHS, AA, DT, &CurrentLoop, | ||
MSSAU, false, LICMFlags); | ||
if (!RHS) | ||
return; | ||
LLVM_DEBUG(dbgs() << "RHS = " << *RHS << "\n"); | ||
|
||
Value *Cmp = Builder.CreateICmp(LoopCmpInstr.getPredicate(), LHS, RHS); | ||
|
||
// Insert Assumption | ||
CallInst *Assumption = Builder.CreateAssumption(Cmp); | ||
AC.registerAssumption(dyn_cast<AssumeInst>(Assumption)); | ||
LLVM_DEBUG(dbgs() << "With Comparator :" << *Cmp << "\n" | ||
<< "Assume :" << *Assumption << "\n"); | ||
LLVM_DEBUG(LoopPreHeader.dump()); | ||
} | ||
|
||
/// Determine if the \param CurrentLoop is not rotated yet and Loop Iteration | ||
/// Count Metadata greater than 0 | ||
bool validLoopIterCount(Loop &CurrentLoop) { | ||
BasicBlock *LoopHeader = CurrentLoop.getHeader(); | ||
|
||
// Dump loop summary | ||
LLVM_DEBUG(if (CurrentLoop.getLoopPreheader()) { | ||
dbgs() << "Preheader:" << CurrentLoop.getLoopPreheader()->getName() << "\n"; | ||
} dbgs() << "LoopIterCountAssumption-Info: Function = " | ||
<< getFunctionAndBlockNames(*LoopHeader) << "\n"); | ||
|
||
std::optional<int64_t> RawMinIterationCount = getMinTripCount(&CurrentLoop); | ||
if (!RawMinIterationCount) { | ||
LLVM_DEBUG(dbgs() << "LoopIterCountAssumptions: Loop Iteration " | ||
"Count not provided for " | ||
<< getFunctionAndBlockNames(*LoopHeader) << "\n"); | ||
return false; | ||
} | ||
|
||
const int64_t MinIterCount = *RawMinIterationCount; | ||
if (MinIterCount <= 0) { | ||
LLVM_DEBUG(dbgs() << "LoopIterCountAssumptions-Warning: Loop Iteration " | ||
"Count is smaller or equal to zero for " | ||
<< getFunctionAndBlockNames(*LoopHeader) << "\n"); | ||
return false; | ||
} | ||
|
||
LLVM_DEBUG(dbgs() << "Processing Loop Iteration Count Metadata: " | ||
<< getFunctionAndBlockNames(*LoopHeader) << " (" | ||
<< MinIterCount << ")\n"); | ||
return true; | ||
} | ||
|
||
} // namespace | ||
|
||
PreservedAnalyses LoopIterCountAssumptions::run(Loop &CurrentLoop, | ||
LoopAnalysisManager &AM, | ||
LoopStandardAnalysisResults &AR, | ||
LPMUpdater &U) { | ||
if (!AR.MSSA) | ||
report_fatal_error( | ||
"LoopIterCountAssumptions requires MemorySSA (loop-mssa)", | ||
/*GenCrashDiag*/ false); | ||
|
||
if (!validLoopIterCount(CurrentLoop)) | ||
return PreservedAnalyses::all(); | ||
|
||
ICmpInst *LoopCmpInstr = getLoopCmpInst(CurrentLoop); | ||
if (!LoopCmpInstr) | ||
return PreservedAnalyses::all(); | ||
|
||
// check that the preheader exists and create it if necessary | ||
MemorySSAUpdater MSSAU(AR.MSSA); | ||
SinkAndHoistLICMFlags LICMFlags(false, CurrentLoop, *AR.MSSA); | ||
insertMinIterAssumption(*LoopCmpInstr, *CurrentLoop.getLoopPreheader(), AR.SE, | ||
AR.AC, &AR.AA, &AR.DT, CurrentLoop, MSSAU, LICMFlags); | ||
|
||
return PreservedAnalyses::all(); | ||
} |
Oops, something went wrong.