[LoopVectorizer] Add support for partial reductions
NickGuy-Arm committed May 17, 2024
1 parent 7f96074 commit 9de3c24
Showing 12 changed files with 291 additions and 13 deletions.
10 changes: 10 additions & 0 deletions llvm/include/llvm/IR/DerivedTypes.h
@@ -512,6 +512,16 @@ class VectorType : public Type {
EltCnt.divideCoefficientBy(2));
}

/// This static method returns a VectorType with a quarter as many elements as
/// the input type and the same element type.
static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
auto EltCnt = VTy->getElementCount();
assert(EltCnt.isKnownMultipleOf(4) &&
"Cannot quarter vector with element count not divisible by 4.");
return VectorType::get(VTy->getElementType(),
EltCnt.divideCoefficientBy(4));
}
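
As an editorial aside (not part of this patch), a minimal sketch of the intended behaviour of the new helper, assuming an LLVMContext is at hand; quarterExample is a hypothetical name:

#include "llvm/IR/DerivedTypes.h"
using namespace llvm;

// <16 x i8> quarters to <4 x i8>; a scalable <vscale x 16 x i8> input would
// map to <vscale x 4 x i8> in the same way.
VectorType *quarterExample(LLVMContext &Ctx) {
VectorType *Wide = FixedVectorType::get(Type::getInt8Ty(Ctx), 16);
return VectorType::getQuarterElementsVectorType(Wide);
}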

/// This static method returns a VectorType with twice as many elements as the
/// input type and the same element type.
static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
5 changes: 3 additions & 2 deletions llvm/include/llvm/IR/Intrinsics.h
@@ -131,6 +131,7 @@ namespace Intrinsic {
ExtendArgument,
TruncArgument,
HalfVecArgument,
QuarterVecArgument,
SameVecWidthArgument,
VecOfAnyPtrsToElt,
VecElementArgument,
@@ -160,15 +161,15 @@

unsigned getArgumentNumber() const {
assert(Kind == Argument || Kind == ExtendArgument ||
Kind == TruncArgument || Kind == HalfVecArgument ||
Kind == QuarterVecArgument || Kind == SameVecWidthArgument ||
Kind == VecElementArgument || Kind == Subdivide2Argument ||
Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
return Argument_Info >> 3;
}
ArgKind getArgumentKind() const {
assert(Kind == Argument || Kind == ExtendArgument ||
Kind == TruncArgument || Kind == HalfVecArgument ||
Kind == QuarterVecArgument || Kind == SameVecWidthArgument ||
Kind == VecElementArgument || Kind == Subdivide2Argument ||
Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
10 changes: 10 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
@@ -321,6 +321,7 @@ def IIT_I4 : IIT_Int<4, 58>;
def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
def IIT_V6 : IIT_Vec<6, 60>;
def IIT_V10 : IIT_Vec<10, 61>;
def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
}

defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -457,6 +458,9 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
class LLVMHalfElementsVectorType<int num>
: LLVMMatchType<num, IIT_HALF_VEC_ARG>;

class LLVMQuarterElementsVectorType<int num>
: LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;

// Match the type of another intrinsic parameter that is expected to be a
// vector type (i.e. <N x iM>) but with each element subdivided to
// form a vector with more elements that are smaller than the original.
@@ -2605,6 +2609,12 @@ def int_experimental_vector_deinterleave2 : DefaultAttrsIntrinsic<[LLVMHalfEleme
[llvm_anyvector_ty],
[IntrNoMem]>;

//===-------------- Intrinsics to perform partial reduction ---------------===//

def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMQuarterElementsVectorType<0>],
[llvm_anyvector_ty],
[IntrNoMem]>;
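
As an aside (not part of this patch), a rough sketch of how the new intrinsic could be emitted through IRBuilder; it assumes the input vector type is the only overloaded type, and emitPartialReduceAdd is a hypothetical helper:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
using namespace llvm;

// Reduces a wide vector (e.g. <16 x i32>) to a result a quarter as wide
// (e.g. <4 x i32>), matching the LLVMQuarterElementsVectorType<0> return type.
Value *emitPartialReduceAdd(IRBuilder<> &Builder, Value *WideVec) {
return Builder.CreateIntrinsic(
Intrinsic::experimental_vector_partial_reduce_add, {WideVec->getType()},
{WideVec});
}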

//===----------------- Pointer Authentication Intrinsics ------------------===//
//

16 changes: 16 additions & 0 deletions llvm/lib/IR/Function.cpp
@@ -1240,6 +1240,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
ArgInfo));
return;
}
case IIT_QUARTER_VEC_ARG: {
unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
ArgInfo));
return;
}
case IIT_SAME_VEC_WIDTH_ARG: {
unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1404,6 +1410,9 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
case IITDescriptor::HalfVecArgument:
return VectorType::getHalfElementsVectorType(cast<VectorType>(
Tys[D.getArgumentNumber()]));
case IITDescriptor::QuarterVecArgument:
return VectorType::getQuarterElementsVectorType(cast<VectorType>(
Tys[D.getArgumentNumber()]));
case IITDescriptor::SameVecWidthArgument: {
Type *EltTy = DecodeFixedType(Infos, Tys, Context);
Type *Ty = Tys[D.getArgumentNumber()];
@@ -1619,6 +1628,13 @@ static bool matchIntrinsicType(
return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
VectorType::getHalfElementsVectorType(
cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
case IITDescriptor::QuarterVecArgument: {
if (D.getArgumentNumber() >= ArgTys.size())
return IsDeferredCheck || DeferCheck(Ty);
return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
VectorType::getQuarterElementsVectorType(
cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
}
case IITDescriptor::SameVecWidthArgument: {
if (D.getArgumentNumber() >= ArgTys.size()) {
// Defer check and subsequent check for the vector element type.
122 changes: 122 additions & 0 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2203,6 +2203,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
}

static void getPartialReductionInstrChain(Instruction *Instr,
SmallVectorImpl<Value *> &Chain) {
Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));

Chain.push_back(Mul);
Chain.push_back(Ext0);
Chain.push_back(Ext1);
Chain.push_back(Instr->getOperand(1));
}


/// Return true if \p Instr is the root of a partial reduction chain: an add
/// that accumulates the product of two zero-extended, same-index loads into a
/// zero-initialised phi.
/// @param Instr The root instruction to scan
static bool isInstrPartialReduction(Instruction *Instr) {
Value *ExpectedPhi;
Value *A, *B;
Value *InductionA, *InductionB;

using namespace llvm::PatternMatch;
auto Pattern = m_Add(
m_OneUse(m_Mul(
m_OneUse(m_ZExt(
m_OneUse(m_Load(
m_GEP(
m_Value(A),
m_Value(InductionA)))))),
m_OneUse(m_ZExt(
m_OneUse(m_Load(
m_GEP(
m_Value(B),
m_Value(InductionB))))))
)), m_Value(ExpectedPhi));

if (!match(Instr, Pattern))
return false;

// Check that the two induction variable uses are to the same induction
// variable.
if (InductionA != InductionB) {
LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input "
"variable, cannot create a partial reduction.\n");
return false;
}

Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));

// Check that the extends extend to i32.
if (!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot "
"create a partial reduction.\n");
return false;
}

// Check that the loads are loading i8.
LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
if (!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a "
"partial reduction.\n");
return false;
}

// Check that the add feeds into ExpectedPhi.
PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
if (!PhiNode) {
LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a "
"partial reduction.\n");
return false;
}

// Check that the first phi value is a zero initializer.
ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
if (!ZeroInit || !ZeroInit->isZero()) {
LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create "
"a partial reduction.\n");
return false;
}

// Check that the second phi value is the instruction we're looking at.
Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
if (!MaybeAdd || MaybeAdd != Instr) {
LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a "
"partial reduction.\n");
return false;
}

return true;
}
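
For reference (not part of this patch), the shape of scalar loop this matcher targets looks roughly like the following dot product, where two i8 loads indexed by the same induction variable are zero-extended to i32, multiplied, and accumulated into a zero-initialised sum:

// Illustrative only: the loads, zero-extensions, multiply and accumulating add
// correspond to the Load/ZExt/Mul/Add chain matched above.
int dotProductU8(const unsigned char *A, const unsigned char *B, int N) {
int Sum = 0;
for (int I = 0; I < N; ++I)
Sum += (int)A[I] * (int)B[I];
return Sum;
}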

// Return true if \p OuterLp is an outer loop annotated with hints for explicit
// vectorization. The loop needs to be annotated with #pragma omp simd
// simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5084,6 +5170,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
return false;
}

// Prevent epilogue vectorization if a partial reduction is involved.
// TODO: Is there a cleaner way to check this?
if (any_of(Legal->getReductionVars(),
[&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
}))
return false;

// Epilogue vectorization code has not been audited to ensure it handles
// non-latch exits properly. It may be fine, but it needs to be audited and
// tested.
@@ -7182,6 +7275,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
VecValuesToIgnore.insert(Casts.begin(), Casts.end());
}

// Ignore any values that we know will be flattened.
for (const auto &Reduction : Legal->getReductionVars()) {
const RecurrenceDescriptor &Recurrence = Reduction.second;
if (isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
SmallVector<Value *, 4> PartialReductionValues;
getPartialReductionInstrChain(Recurrence.getLoopExitInstr(),
PartialReductionValues);
ValuesToIgnore.insert(PartialReductionValues.begin(),
PartialReductionValues.end());
VecValuesToIgnore.insert(PartialReductionValues.begin(),
PartialReductionValues.end());
}
}
}

void LoopVectorizationCostModel::collectInLoopReductions() {
@@ -8536,9 +8640,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
*CI);
}

if (auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
return PartialReduce;

return tryToWiden(Instr, Operands, VPBB);
}

VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
if (!isInstrPartialReduction(Instr))
return nullptr;

// Partial reduction recipes are currently only created when the scalable
// VF of 16 is part of the candidate range.
auto EC = ElementCount::getScalable(16);
if (!is_contained(Range, EC))
return nullptr;

return new VPPartialReductionRecipe(
*Instr, make_range(Operands.begin(), Operands.end()));
}

void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
ElementCount MaxVF) {
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8746,6 +8865,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
VPBB->appendRecipe(Recipe);
}

// Now that all recipes in the block have been created, give each recipe a
// chance to perform post-insertion adjustments.
for (auto &Recipe : *VPBB)
Recipe.postInsertionOp();

VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
}
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,6 +116,8 @@ class VPRecipeBuilder {
ArrayRef<VPValue *> Operands,
VFRange &Range, VPBasicBlock *VPBB);

/// Create a partial reduction recipe for \p Instr if it is the root of a
/// partial reduction chain and the VFs in \p Range allow it; return nullptr
/// otherwise.
VPRecipeBase *tryToCreatePartialReduction(VFRange &Range, Instruction *Instr,
ArrayRef<VPValue *> Operands);

/// Set the recipe created for given ingredient.
void setRecipe(Instruction *I, VPRecipeBase *R) {
assert(!Ingredient2Recipe.contains(I) &&
43 changes: 40 additions & 3 deletions llvm/lib/Transforms/Vectorize/VPlan.h
@@ -767,6 +767,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
/// \returns an iterator pointing to the element after the erased one
iplist<VPRecipeBase>::iterator eraseFromParent();

/// Hook called after this recipe has been inserted into its VPBasicBlock;
/// the default implementation does nothing.
virtual void postInsertionOp() {}

/// Method to support type inquiry through isa, cast, and dyn_cast.
static inline bool classof(const VPDef *D) {
// All VPDefs are also VPRecipeBases.
@@ -1881,14 +1883,19 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
/// The phi is part of an ordered reduction. Requires IsInLoop to be true.
bool IsOrdered;

/// The amount by which the VF should be divided during ::execute().
unsigned VFScaleFactor = 1;

public:

/// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
/// RdxDesc.
VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
VPValue &Start, bool IsInLoop = false,
bool IsOrdered = false, unsigned VFScaleFactor = 1)
: VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start),
RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered),
VFScaleFactor(VFScaleFactor) {
assert((!IsOrdered || IsInLoop) && "IsOrdered requires IsInLoop");
}

@@ -1897,7 +1904,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
VPReductionPHIRecipe *clone() override {
auto *R =
new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
*getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
R->addOperand(getBackedgeValue());
return R;
}
@@ -1908,6 +1915,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
return R->getVPDefID() == VPDef::VPReductionPHISC;
}

void setVFScaleFactor(unsigned ScaleFactor) { VFScaleFactor = ScaleFactor; }

/// Generate the phi/select nodes.
void execute(VPTransformState &State) override;

@@ -1928,6 +1939,32 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
bool isInLoop() const { return IsInLoop; }
};

/// A recipe for forming a partial reduction, in which a wide vector operand
/// is accumulated into a narrower accumulator vector.
class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
unsigned Opcode;

public:
template <typename IterT>
VPPartialReductionRecipe(Instruction &I, iterator_range<IterT> Operands)
: VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands, I),
Opcode(I.getOpcode()) {}

~VPPartialReductionRecipe() override = default;

VPPartialReductionRecipe *clone() override {
auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
R->transferFlags(*this);
return R;
}

VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)

/// Generate the partial reduction in the loop.
void execute(VPTransformState &State) override;

void postInsertionOp() override;

unsigned getOpcode() const { return Opcode; }

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print the recipe.
void print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const override;
#endif
};

/// A recipe for vectorizing a phi-node as a sequence of mask-based select
/// instructions.
class VPBlendRecipe : public VPSingleDefRecipe {
6 changes: 5 additions & 1 deletion llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -208,6 +208,10 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
llvm_unreachable("Unhandled opcode");
}

Type *VPTypeAnalysis::inferScalarTypeForRecipe(
const VPPartialReductionRecipe *R) {
return R->getUnderlyingInstr()->getType();
}

Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
if (Type *CachedTy = CachedTypes.lookup(V))
return CachedTy;
@@ -238,7 +242,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
return inferScalarType(R->getOperand(0));
})
.Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe,
VPPartialReductionRecipe>(
[this](const auto *R) { return inferScalarTypeForRecipe(R); })
.Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
// TODO: Use info from interleave group.
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -23,6 +23,7 @@ class VPWidenIntOrFpInductionRecipe;
class VPWidenMemoryRecipe;
struct VPWidenSelectRecipe;
class VPReplicateRecipe;
class VPPartialReductionRecipe;
class Type;

/// An analysis for type-inference for VPValues.
@@ -49,6 +50,7 @@ class VPTypeAnalysis {
Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
Type *inferScalarTypeForRecipe(const VPPartialReductionRecipe *R);

public:
VPTypeAnalysis(Type *CanonicalIVTy, LLVMContext &Ctx)