Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoopVectorizer] Add support for partial reductions #92418

Open
wants to merge 50 commits into
base: main
Choose a base branch
from

Conversation

NickGuy-Arm
Copy link
Contributor

@NickGuy-Arm NickGuy-Arm commented May 16, 2024

Following on from #94499, this patch adds support to the Loop Vectorizer to emit the partial reduction intrinsics where they may be beneficial for the target.

@llvmbot
Copy link
Member

llvmbot commented May 16, 2024

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-aarch64

Author: None (NickGuy-Arm)

Changes

This patch adds to the loop vectorizer support for partial reductions; that is a reduction from a wider vector to a narrower vector.


Patch is 29.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92418.diff

12 Files Affected:

  • (modified) llvm/include/llvm/IR/DerivedTypes.h (+10)
  • (modified) llvm/include/llvm/IR/Intrinsics.h (+3-2)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+10)
  • (modified) llvm/lib/IR/Function.cpp (+16)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+122)
  • (modified) llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+40-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+5-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+76-4)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1)
  • (added) llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll (+100)
diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 443fb7de3b821..866a01c9afebd 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -512,6 +512,16 @@ class VectorType : public Type {
                            EltCnt.divideCoefficientBy(2));
   }
 
+  /// This static method returns a VectorType with quarter as many elements as the
+  /// input type and the same element type.
+  static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
+    auto EltCnt = VTy->getElementCount();
+    assert(EltCnt.isKnownEven() &&
+           "Cannot halve vector with odd number of elements.");
+    return VectorType::get(VTy->getElementType(),
+                           EltCnt.divideCoefficientBy(4));
+  }
+
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index 340c1c326d066..e03e7e0bf50de 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -131,6 +131,7 @@ namespace Intrinsic {
       ExtendArgument,
       TruncArgument,
       HalfVecArgument,
+      QuarterVecArgument,
       SameVecWidthArgument,
       VecOfAnyPtrsToElt,
       VecElementArgument,
@@ -160,7 +161,7 @@ namespace Intrinsic {
 
     unsigned getArgumentNumber() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument || Kind == VecElementArgument ||
              Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
              Kind == VecOfBitcastsToInt);
@@ -168,7 +169,7 @@ namespace Intrinsic {
     }
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument ||
              Kind == VecElementArgument || Kind == Subdivide2Argument ||
              Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 1d20f7e1b1985..dad177e595341 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -321,6 +321,7 @@ def IIT_I4 : IIT_Int<4, 58>;
 def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
 def IIT_V6 : IIT_Vec<6, 60>;
 def IIT_V10 : IIT_Vec<10, 61>;
+def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
 }
 
 defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -457,6 +458,9 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
 class LLVMHalfElementsVectorType<int num>
   : LLVMMatchType<num, IIT_HALF_VEC_ARG>;
 
+class LLVMQuarterElementsVectorType<int num>
+  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
+
 // Match the type of another intrinsic parameter that is expected to be a
 // vector type (i.e. <N x iM>) but with each element subdivided to
 // form a vector with more elements that are smaller than the original.
@@ -2605,6 +2609,12 @@ def int_experimental_vector_deinterleave2 : DefaultAttrsIntrinsic<[LLVMHalfEleme
                                                                   [llvm_anyvector_ty],
                                                                   [IntrNoMem]>;
 
+//===-------------- Intrinsics to perform partial reduction ---------------===//
+
+def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMQuarterElementsVectorType<0>],
+                                                                       [llvm_anyvector_ty],
+                                                                       [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index e66fe73425e86..e9eebd5e35300 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1240,6 +1240,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
                                              ArgInfo));
     return;
   }
+  case IIT_QUARTER_VEC_ARG: {
+    unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
+    OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
+                                             ArgInfo));
+    return;
+  }
   case IIT_SAME_VEC_WIDTH_ARG: {
     unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1404,6 +1410,9 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
   case IITDescriptor::HalfVecArgument:
     return VectorType::getHalfElementsVectorType(cast<VectorType>(
                                                   Tys[D.getArgumentNumber()]));
+  case IITDescriptor::QuarterVecArgument:  {
+    return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
+  }
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1619,6 +1628,13 @@ static bool matchIntrinsicType(
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
                      cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    case IITDescriptor::QuarterVecArgument: {
+    if (D.getArgumentNumber() >= ArgTys.size())
+        return IsDeferredCheck || DeferCheck(Ty);
+      return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
+             VectorType::getQuarterElementsVectorType(
+                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    }
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 33c4decd58a6c..1f37df061bbf7 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2203,6 +2203,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
+static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*, 4> &Chain) {
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  Chain.push_back(Mul);
+  Chain.push_back(Ext0);
+  Chain.push_back(Ext1);
+  Chain.push_back(Instr->getOperand(1));
+}
+
+
+/// @param Instr The root instruction to scan
+static bool isInstrPartialReduction(Instruction *Instr) {
+  Value *ExpectedPhi;
+  Value *A, *B;
+  Value *InductionA, *InductionB;
+
+  using namespace llvm::PatternMatch;
+  auto Pattern = m_Add(
+    m_OneUse(m_Mul(
+      m_OneUse(m_ZExt(
+        m_OneUse(m_Load(
+          m_GEP(
+              m_Value(A),
+              m_Value(InductionA)))))),
+      m_OneUse(m_ZExt(
+        m_OneUse(m_Load(
+          m_GEP(
+              m_Value(B),
+              m_Value(InductionB))))))
+        )), m_Value(ExpectedPhi));
+
+  bool Matches = match(Instr, Pattern);
+
+  if(!Matches)
+    return false;
+
+  // Check that the two induction variable uses are to the same induction variable
+  if(InductionA != InductionB) {
+    LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  // Check that the extends extend to i32
+  if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the loads are loading i8
+  LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
+  LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
+  if(!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
+    LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
+    return false;
+  }
+
+  // Check that the add feeds into ExpectedPhi
+  PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
+  if(!PhiNode) {
+    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the first phi value is a zero initializer
+  ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
+  if(!ZeroInit || !ZeroInit->isZero()) {
+    LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the second phi value is the instruction we're looking at
+  Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
+  if(!MaybeAdd || MaybeAdd != Instr) {
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  return true;
+}
+
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5084,6 +5170,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
         return false;
   }
 
+  // Prevent epilogue vectorization if a partial reduction is involved
+  // TODO Is there a cleaner way to check this?
+  if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+    return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
+  }))
+    return false;
+
   // Epilogue vectorization code has not been auditted to ensure it handles
   // non-latch exits properly.  It may be fine, but it needs auditted and
   // tested.
@@ -7182,6 +7275,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
     const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
     VecValuesToIgnore.insert(Casts.begin(), Casts.end());
   }
+
+  // Ignore any values that we know will be flattened
+  for(auto Reduction : this->Legal->getReductionVars()) {
+    auto &Recurrence = Reduction.second;
+    if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
+      SmallVector<Value*, 4> PartialReductionValues;
+      getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
+      ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+      VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+    }
+  }
 }
 
 void LoopVectorizationCostModel::collectInLoopReductions() {
@@ -8536,9 +8640,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                  *CI);
   }
 
+  if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
+    return PartialReduce;
+
   return tryToWiden(Instr, Operands, VPBB);
 }
 
+VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
+    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
+
+  if(isInstrPartialReduction(Instr)) {
+    auto EC = ElementCount::getScalable(16);
+    if(std::find(Range.begin(), Range.end(), EC) == Range.end())
+      return nullptr;
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
+  }
+  return nullptr;
+}
+
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
                                                         ElementCount MaxVF) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8746,6 +8865,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
+    for(auto &Recipe : *VPBB)
+      Recipe.postInsertionOp();
+
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
     VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index b4c7ab02f928f..c439f221709e1 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,6 +116,8 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
+  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
     assert(!Ingredient2Recipe.contains(I) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c74329a0bcc4a..5a572ecb798d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -767,6 +767,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
   /// \returns an iterator pointing to the element after the erased one
   iplist<VPRecipeBase>::iterator eraseFromParent();
 
+  virtual void postInsertionOp() {}
+
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPDef *D) {
     // All VPDefs are also VPRecipeBases.
@@ -1881,14 +1883,19 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   /// The phi is part of an ordered reduction. Requires IsInLoop to be true.
   bool IsOrdered;
 
+  /// The amount that the VF should be divided by during ::execute
+  unsigned VFScaleFactor = 1;
+
 public:
+
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
   /// RdxDesc.
   VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
                        VPValue &Start, bool IsInLoop = false,
-                       bool IsOrdered = false)
+                       bool IsOrdered = false, unsigned VFScaleFactor = 1)
       : VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start),
-        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered) {
+        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered),
+        VFScaleFactor(VFScaleFactor) {
     assert((!IsOrdered || IsInLoop) && "IsOrdered requires IsInLoop");
   }
 
@@ -1897,7 +1904,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   VPReductionPHIRecipe *clone() override {
     auto *R =
         new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
-                                 *getOperand(0), IsInLoop, IsOrdered);
+                                 *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
     R->addOperand(getBackedgeValue());
     return R;
   }
@@ -1908,6 +1915,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
+  void SetVFScaleFactor(unsigned ScaleFactor) {
+    VFScaleFactor = ScaleFactor;
+  }
+
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
 
@@ -1928,6 +1939,32 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   bool isInLoop() const { return IsInLoop; }
 };
 
+class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
+  unsigned Opcode;
+public:
+  template <typename IterT>
+  VPPartialReductionRecipe(Instruction &I,
+                           iterator_range<IterT> Operands) : VPRecipeWithIRFlags(
+    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode())
+  {}
+  ~VPPartialReductionRecipe() override = default;
+  VPPartialReductionRecipe *clone() override {
+    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
+    R->transferFlags(*this);
+    return R;
+  }
+  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override;
+  void postInsertionOp() override;
+  unsigned getOpcode() { return Opcode; }
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+};
+
 /// A recipe for vectorizing a phi-node as a sequence of mask-based select
 /// instructions.
 class VPBlendRecipe : public VPSingleDefRecipe {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 5f93339083f0c..8a75668886599 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -208,6 +208,10 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   llvm_unreachable("Unhandled opcode");
 }
 
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+  return R->getUnderlyingInstr()->getType();
+}
+
 Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   if (Type *CachedTy = CachedTypes.lookup(V))
     return CachedTy;
@@ -238,7 +242,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             return inferScalarType(R->getOperand(0));
           })
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
-                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe>(
+                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe, VPPartialReductionRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7d310b1b31b6f..3bd8d24542199 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -23,6 +23,7 @@ class VPWidenIntOrFpInductionRecipe;
 class VPWidenMemoryRecipe;
 struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
+class VPPartialReductionRecipe;
 class Type;
 
 /// An analysis for type-inference for VPValues.
@@ -49,6 +50,7 @@ class VPTypeAnalysis {
   Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
   Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
   Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPPartialReductionRecipe *R);
 
 public:
   VPTypeAnalysis(Type *CanonicalIVTy, LLVMContext &Ctx)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 9ec422ec002c8..9aff5dd0a7771 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -245,6 +245,76 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB,
   insertBefore(BB, I);
 }
 
+void VPPartialReductionRecipe::execute(VPTransformState &State) {
+  State.setDebugLocFrom(getDebugLoc());
+  auto &Builder = State.Builder;
+
+  switch(Opcode) {
+  case Instruction::Add: {
+
+    for (unsigned Part = 0; Part < State.UF; ++Part) {
+      Value* Mul = nullptr;
+      Value* Phi = nullptr;
+      SmallVector<Value*, 2> Ops;
+      for (VPValue *VPOp : operands()) {
+        auto *Op = State.get(VPOp, Part);
+        Ops.push_back(Op);
+        if(isa<PHINode>(Op))
+          Phi = Op;
+        else
+          Mul = Op;
+      }
+
+      assert(Phi && Mul && "Phi and Mul must be set");
+      assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
+
+      ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
+      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), 4);
+
+      Intrinsic:...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 16, 2024

@llvm/pr-subscribers-llvm-ir

Author: None (NickGuy-Arm)

Changes

This patch adds to the loop vectorizer support for partial reductions; that is a reduction from a wider vector to a narrower vector.


Patch is 29.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92418.diff

12 Files Affected:

  • (modified) llvm/include/llvm/IR/DerivedTypes.h (+10)
  • (modified) llvm/include/llvm/IR/Intrinsics.h (+3-2)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+10)
  • (modified) llvm/lib/IR/Function.cpp (+16)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+122)
  • (modified) llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+40-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+5-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+76-4)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1)
  • (added) llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll (+100)
diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 443fb7de3b821..866a01c9afebd 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -512,6 +512,16 @@ class VectorType : public Type {
                            EltCnt.divideCoefficientBy(2));
   }
 
+  /// This static method returns a VectorType with quarter as many elements as the
+  /// input type and the same element type.
+  static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
+    auto EltCnt = VTy->getElementCount();
+    assert(EltCnt.isKnownEven() &&
+           "Cannot halve vector with odd number of elements.");
+    return VectorType::get(VTy->getElementType(),
+                           EltCnt.divideCoefficientBy(4));
+  }
+
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index 340c1c326d066..e03e7e0bf50de 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -131,6 +131,7 @@ namespace Intrinsic {
       ExtendArgument,
       TruncArgument,
       HalfVecArgument,
+      QuarterVecArgument,
       SameVecWidthArgument,
       VecOfAnyPtrsToElt,
       VecElementArgument,
@@ -160,7 +161,7 @@ namespace Intrinsic {
 
     unsigned getArgumentNumber() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument || Kind == VecElementArgument ||
              Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
              Kind == VecOfBitcastsToInt);
@@ -168,7 +169,7 @@ namespace Intrinsic {
     }
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument ||
              Kind == VecElementArgument || Kind == Subdivide2Argument ||
              Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 1d20f7e1b1985..dad177e595341 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -321,6 +321,7 @@ def IIT_I4 : IIT_Int<4, 58>;
 def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
 def IIT_V6 : IIT_Vec<6, 60>;
 def IIT_V10 : IIT_Vec<10, 61>;
+def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
 }
 
 defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -457,6 +458,9 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
 class LLVMHalfElementsVectorType<int num>
   : LLVMMatchType<num, IIT_HALF_VEC_ARG>;
 
+class LLVMQuarterElementsVectorType<int num>
+  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
+
 // Match the type of another intrinsic parameter that is expected to be a
 // vector type (i.e. <N x iM>) but with each element subdivided to
 // form a vector with more elements that are smaller than the original.
@@ -2605,6 +2609,12 @@ def int_experimental_vector_deinterleave2 : DefaultAttrsIntrinsic<[LLVMHalfEleme
                                                                   [llvm_anyvector_ty],
                                                                   [IntrNoMem]>;
 
+//===-------------- Intrinsics to perform partial reduction ---------------===//
+
+def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMQuarterElementsVectorType<0>],
+                                                                       [llvm_anyvector_ty],
+                                                                       [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index e66fe73425e86..e9eebd5e35300 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1240,6 +1240,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
                                              ArgInfo));
     return;
   }
+  case IIT_QUARTER_VEC_ARG: {
+    unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
+    OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
+                                             ArgInfo));
+    return;
+  }
   case IIT_SAME_VEC_WIDTH_ARG: {
     unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1404,6 +1410,9 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
   case IITDescriptor::HalfVecArgument:
     return VectorType::getHalfElementsVectorType(cast<VectorType>(
                                                   Tys[D.getArgumentNumber()]));
+  case IITDescriptor::QuarterVecArgument:  {
+    return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
+  }
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1619,6 +1628,13 @@ static bool matchIntrinsicType(
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
                      cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    case IITDescriptor::QuarterVecArgument: {
+    if (D.getArgumentNumber() >= ArgTys.size())
+        return IsDeferredCheck || DeferCheck(Ty);
+      return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
+             VectorType::getQuarterElementsVectorType(
+                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    }
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 33c4decd58a6c..1f37df061bbf7 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2203,6 +2203,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
+static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*, 4> &Chain) {
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  Chain.push_back(Mul);
+  Chain.push_back(Ext0);
+  Chain.push_back(Ext1);
+  Chain.push_back(Instr->getOperand(1));
+}
+
+
+/// @param Instr The root instruction to scan
+static bool isInstrPartialReduction(Instruction *Instr) {
+  Value *ExpectedPhi;
+  Value *A, *B;
+  Value *InductionA, *InductionB;
+
+  using namespace llvm::PatternMatch;
+  auto Pattern = m_Add(
+    m_OneUse(m_Mul(
+      m_OneUse(m_ZExt(
+        m_OneUse(m_Load(
+          m_GEP(
+              m_Value(A),
+              m_Value(InductionA)))))),
+      m_OneUse(m_ZExt(
+        m_OneUse(m_Load(
+          m_GEP(
+              m_Value(B),
+              m_Value(InductionB))))))
+        )), m_Value(ExpectedPhi));
+
+  bool Matches = match(Instr, Pattern);
+
+  if(!Matches)
+    return false;
+
+  // Check that the two induction variable uses are to the same induction variable
+  if(InductionA != InductionB) {
+    LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  // Check that the extends extend to i32
+  if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the loads are loading i8
+  LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
+  LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
+  if(!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
+    LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
+    return false;
+  }
+
+  // Check that the add feeds into ExpectedPhi
+  PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
+  if(!PhiNode) {
+    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the first phi value is a zero initializer
+  ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
+  if(!ZeroInit || !ZeroInit->isZero()) {
+    LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the second phi value is the instruction we're looking at
+  Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
+  if(!MaybeAdd || MaybeAdd != Instr) {
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  return true;
+}
+
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5084,6 +5170,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
         return false;
   }
 
+  // Prevent epilogue vectorization if a partial reduction is involved
+  // TODO Is there a cleaner way to check this?
+  if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+    return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
+  }))
+    return false;
+
   // Epilogue vectorization code has not been auditted to ensure it handles
   // non-latch exits properly.  It may be fine, but it needs auditted and
   // tested.
@@ -7182,6 +7275,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
     const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
     VecValuesToIgnore.insert(Casts.begin(), Casts.end());
   }
+
+  // Ignore any values that we know will be flattened
+  for(auto Reduction : this->Legal->getReductionVars()) {
+    auto &Recurrence = Reduction.second;
+    if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
+      SmallVector<Value*, 4> PartialReductionValues;
+      getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
+      ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+      VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+    }
+  }
 }
 
 void LoopVectorizationCostModel::collectInLoopReductions() {
@@ -8536,9 +8640,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                  *CI);
   }
 
+  if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
+    return PartialReduce;
+
   return tryToWiden(Instr, Operands, VPBB);
 }
 
+VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
+    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
+
+  if(isInstrPartialReduction(Instr)) {
+    auto EC = ElementCount::getScalable(16);
+    if(std::find(Range.begin(), Range.end(), EC) == Range.end())
+      return nullptr;
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
+  }
+  return nullptr;
+}
+
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
                                                         ElementCount MaxVF) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8746,6 +8865,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
+    for(auto &Recipe : *VPBB)
+      Recipe.postInsertionOp();
+
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
     VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index b4c7ab02f928f..c439f221709e1 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,6 +116,8 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
+  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
     assert(!Ingredient2Recipe.contains(I) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c74329a0bcc4a..5a572ecb798d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -767,6 +767,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
   /// \returns an iterator pointing to the element after the erased one
   iplist<VPRecipeBase>::iterator eraseFromParent();
 
+  virtual void postInsertionOp() {}
+
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPDef *D) {
     // All VPDefs are also VPRecipeBases.
@@ -1881,14 +1883,19 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   /// The phi is part of an ordered reduction. Requires IsInLoop to be true.
   bool IsOrdered;
 
+  /// The amount that the VF should be divided by during ::execute
+  unsigned VFScaleFactor = 1;
+
 public:
+
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
   /// RdxDesc.
   VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
                        VPValue &Start, bool IsInLoop = false,
-                       bool IsOrdered = false)
+                       bool IsOrdered = false, unsigned VFScaleFactor = 1)
       : VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start),
-        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered) {
+        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered),
+        VFScaleFactor(VFScaleFactor) {
     assert((!IsOrdered || IsInLoop) && "IsOrdered requires IsInLoop");
   }
 
@@ -1897,7 +1904,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   VPReductionPHIRecipe *clone() override {
     auto *R =
         new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
-                                 *getOperand(0), IsInLoop, IsOrdered);
+                                 *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
     R->addOperand(getBackedgeValue());
     return R;
   }
@@ -1908,6 +1915,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
+  void SetVFScaleFactor(unsigned ScaleFactor) {
+    VFScaleFactor = ScaleFactor;
+  }
+
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
 
@@ -1928,6 +1939,32 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   bool isInLoop() const { return IsInLoop; }
 };
 
+class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
+  unsigned Opcode;
+public:
+  template <typename IterT>
+  VPPartialReductionRecipe(Instruction &I,
+                           iterator_range<IterT> Operands) : VPRecipeWithIRFlags(
+    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode())
+  {}
+  ~VPPartialReductionRecipe() override = default;
+  VPPartialReductionRecipe *clone() override {
+    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
+    R->transferFlags(*this);
+    return R;
+  }
+  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override;
+  void postInsertionOp() override;
+  unsigned getOpcode() { return Opcode; }
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+};
+
 /// A recipe for vectorizing a phi-node as a sequence of mask-based select
 /// instructions.
 class VPBlendRecipe : public VPSingleDefRecipe {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 5f93339083f0c..8a75668886599 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -208,6 +208,10 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   llvm_unreachable("Unhandled opcode");
 }
 
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+  return R->getUnderlyingInstr()->getType();
+}
+
 Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   if (Type *CachedTy = CachedTypes.lookup(V))
     return CachedTy;
@@ -238,7 +242,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             return inferScalarType(R->getOperand(0));
           })
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
-                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe>(
+                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe, VPPartialReductionRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7d310b1b31b6f..3bd8d24542199 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -23,6 +23,7 @@ class VPWidenIntOrFpInductionRecipe;
 class VPWidenMemoryRecipe;
 struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
+class VPPartialReductionRecipe;
 class Type;
 
 /// An analysis for type-inference for VPValues.
@@ -49,6 +50,7 @@ class VPTypeAnalysis {
   Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
   Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
   Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPPartialReductionRecipe *R);
 
 public:
   VPTypeAnalysis(Type *CanonicalIVTy, LLVMContext &Ctx)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 9ec422ec002c8..9aff5dd0a7771 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -245,6 +245,76 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB,
   insertBefore(BB, I);
 }
 
+void VPPartialReductionRecipe::execute(VPTransformState &State) {
+  State.setDebugLocFrom(getDebugLoc());
+  auto &Builder = State.Builder;
+
+  switch(Opcode) {
+  case Instruction::Add: {
+
+    for (unsigned Part = 0; Part < State.UF; ++Part) {
+      Value* Mul = nullptr;
+      Value* Phi = nullptr;
+      SmallVector<Value*, 2> Ops;
+      for (VPValue *VPOp : operands()) {
+        auto *Op = State.get(VPOp, Part);
+        Ops.push_back(Op);
+        if(isa<PHINode>(Op))
+          Phi = Op;
+        else
+          Mul = Op;
+      }
+
+      assert(Phi && Mul && "Phi and Mul must be set");
+      assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
+
+      ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
+      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), 4);
+
+      Intrinsic:...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 16, 2024

@llvm/pr-subscribers-llvm-transforms

Author: None (NickGuy-Arm)

Changes

This patch adds to the loop vectorizer support for partial reductions; that is a reduction from a wider vector to a narrower vector.


Patch is 29.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92418.diff

12 Files Affected:

  • (modified) llvm/include/llvm/IR/DerivedTypes.h (+10)
  • (modified) llvm/include/llvm/IR/Intrinsics.h (+3-2)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+10)
  • (modified) llvm/lib/IR/Function.cpp (+16)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+122)
  • (modified) llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+40-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+5-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+76-4)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1)
  • (added) llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll (+100)
diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 443fb7de3b821..866a01c9afebd 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -512,6 +512,16 @@ class VectorType : public Type {
                            EltCnt.divideCoefficientBy(2));
   }
 
+  /// This static method returns a VectorType with quarter as many elements as the
+  /// input type and the same element type.
+  static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
+    auto EltCnt = VTy->getElementCount();
+    assert(EltCnt.isKnownEven() &&
+           "Cannot halve vector with odd number of elements.");
+    return VectorType::get(VTy->getElementType(),
+                           EltCnt.divideCoefficientBy(4));
+  }
+
   /// This static method returns a VectorType with twice as many elements as the
   /// input type and the same element type.
   static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index 340c1c326d066..e03e7e0bf50de 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -131,6 +131,7 @@ namespace Intrinsic {
       ExtendArgument,
       TruncArgument,
       HalfVecArgument,
+      QuarterVecArgument,
       SameVecWidthArgument,
       VecOfAnyPtrsToElt,
       VecElementArgument,
@@ -160,7 +161,7 @@ namespace Intrinsic {
 
     unsigned getArgumentNumber() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument || Kind == VecElementArgument ||
              Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
              Kind == VecOfBitcastsToInt);
@@ -168,7 +169,7 @@ namespace Intrinsic {
     }
     ArgKind getArgumentKind() const {
       assert(Kind == Argument || Kind == ExtendArgument ||
-             Kind == TruncArgument || Kind == HalfVecArgument ||
+             Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
              Kind == SameVecWidthArgument ||
              Kind == VecElementArgument || Kind == Subdivide2Argument ||
              Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 1d20f7e1b1985..dad177e595341 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -321,6 +321,7 @@ def IIT_I4 : IIT_Int<4, 58>;
 def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
 def IIT_V6 : IIT_Vec<6, 60>;
 def IIT_V10 : IIT_Vec<10, 61>;
+def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
 }
 
 defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -457,6 +458,9 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
 class LLVMHalfElementsVectorType<int num>
   : LLVMMatchType<num, IIT_HALF_VEC_ARG>;
 
+class LLVMQuarterElementsVectorType<int num>
+  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
+
 // Match the type of another intrinsic parameter that is expected to be a
 // vector type (i.e. <N x iM>) but with each element subdivided to
 // form a vector with more elements that are smaller than the original.
@@ -2605,6 +2609,12 @@ def int_experimental_vector_deinterleave2 : DefaultAttrsIntrinsic<[LLVMHalfEleme
                                                                   [llvm_anyvector_ty],
                                                                   [IntrNoMem]>;
 
+//===-------------- Intrinsics to perform partial reduction ---------------===//
+
+def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMQuarterElementsVectorType<0>],
+                                                                       [llvm_anyvector_ty],
+                                                                       [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index e66fe73425e86..e9eebd5e35300 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1240,6 +1240,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
                                              ArgInfo));
     return;
   }
+  case IIT_QUARTER_VEC_ARG: {
+    unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
+    OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
+                                             ArgInfo));
+    return;
+  }
   case IIT_SAME_VEC_WIDTH_ARG: {
     unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
     OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1404,6 +1410,9 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
   case IITDescriptor::HalfVecArgument:
     return VectorType::getHalfElementsVectorType(cast<VectorType>(
                                                   Tys[D.getArgumentNumber()]));
+  case IITDescriptor::QuarterVecArgument:  {
+    return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
+  }
   case IITDescriptor::SameVecWidthArgument: {
     Type *EltTy = DecodeFixedType(Infos, Tys, Context);
     Type *Ty = Tys[D.getArgumentNumber()];
@@ -1619,6 +1628,13 @@ static bool matchIntrinsicType(
       return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
              VectorType::getHalfElementsVectorType(
                      cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    case IITDescriptor::QuarterVecArgument: {
+    if (D.getArgumentNumber() >= ArgTys.size())
+        return IsDeferredCheck || DeferCheck(Ty);
+      return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
+             VectorType::getQuarterElementsVectorType(
+                     cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+    }
     case IITDescriptor::SameVecWidthArgument: {
       if (D.getArgumentNumber() >= ArgTys.size()) {
         // Defer check and subsequent check for the vector element type.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 33c4decd58a6c..1f37df061bbf7 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -2203,6 +2203,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
          Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
 }
 
+static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*, 4> &Chain) {
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  Chain.push_back(Mul);
+  Chain.push_back(Ext0);
+  Chain.push_back(Ext1);
+  Chain.push_back(Instr->getOperand(1));
+}
+
+
+/// @param Instr The root instruction to scan
+static bool isInstrPartialReduction(Instruction *Instr) {
+  Value *ExpectedPhi;
+  Value *A, *B;
+  Value *InductionA, *InductionB;
+
+  using namespace llvm::PatternMatch;
+  auto Pattern = m_Add(
+    m_OneUse(m_Mul(
+      m_OneUse(m_ZExt(
+        m_OneUse(m_Load(
+          m_GEP(
+              m_Value(A),
+              m_Value(InductionA)))))),
+      m_OneUse(m_ZExt(
+        m_OneUse(m_Load(
+          m_GEP(
+              m_Value(B),
+              m_Value(InductionB))))))
+        )), m_Value(ExpectedPhi));
+
+  bool Matches = match(Instr, Pattern);
+
+  if(!Matches)
+    return false;
+
+  // Check that the two induction variable uses are to the same induction variable
+  if(InductionA != InductionB) {
+    LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
+  Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
+  Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
+
+  // Check that the extends extend to i32
+  if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
+    LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the loads are loading i8
+  LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
+  LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
+  if(!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
+    LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
+    return false;
+  }
+
+  // Check that the add feeds into ExpectedPhi
+  PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
+  if(!PhiNode) {
+    LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the first phi value is a zero initializer
+  ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
+  if(!ZeroInit || !ZeroInit->isZero()) {
+    LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  // Check that the second phi value is the instruction we're looking at
+  Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
+  if(!MaybeAdd || MaybeAdd != Instr) {
+    LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
+    return false;
+  }
+
+  return true;
+}
+
 // Return true if \p OuterLp is an outer loop annotated with hints for explicit
 // vectorization. The loop needs to be annotated with #pragma omp simd
 // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5084,6 +5170,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
         return false;
   }
 
+  // Prevent epilogue vectorization if a partial reduction is involved
+  // TODO Is there a cleaner way to check this?
+  if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
+    return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
+  }))
+    return false;
+
   // Epilogue vectorization code has not been auditted to ensure it handles
   // non-latch exits properly.  It may be fine, but it needs auditted and
   // tested.
@@ -7182,6 +7275,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
     const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
     VecValuesToIgnore.insert(Casts.begin(), Casts.end());
   }
+
+  // Ignore any values that we know will be flattened
+  for(auto Reduction : this->Legal->getReductionVars()) {
+    auto &Recurrence = Reduction.second;
+    if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
+      SmallVector<Value*, 4> PartialReductionValues;
+      getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
+      ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+      VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
+    }
+  }
 }
 
 void LoopVectorizationCostModel::collectInLoopReductions() {
@@ -8536,9 +8640,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                  *CI);
   }
 
+  if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
+    return PartialReduce;
+
   return tryToWiden(Instr, Operands, VPBB);
 }
 
+VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
+    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
+
+  if(isInstrPartialReduction(Instr)) {
+    auto EC = ElementCount::getScalable(16);
+    if(std::find(Range.begin(), Range.end(), EC) == Range.end())
+      return nullptr;
+    return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
+  }
+  return nullptr;
+}
+
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
                                                         ElementCount MaxVF) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8746,6 +8865,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         VPBB->appendRecipe(Recipe);
     }
 
+    for(auto &Recipe : *VPBB)
+      Recipe.postInsertionOp();
+
     VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
     VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index b4c7ab02f928f..c439f221709e1 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -116,6 +116,8 @@ class VPRecipeBuilder {
                                        ArrayRef<VPValue *> Operands,
                                        VFRange &Range, VPBasicBlock *VPBB);
 
+  VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
+
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
     assert(!Ingredient2Recipe.contains(I) &&
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index c74329a0bcc4a..5a572ecb798d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -767,6 +767,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
   /// \returns an iterator pointing to the element after the erased one
   iplist<VPRecipeBase>::iterator eraseFromParent();
 
+  virtual void postInsertionOp() {}
+
   /// Method to support type inquiry through isa, cast, and dyn_cast.
   static inline bool classof(const VPDef *D) {
     // All VPDefs are also VPRecipeBases.
@@ -1881,14 +1883,19 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   /// The phi is part of an ordered reduction. Requires IsInLoop to be true.
   bool IsOrdered;
 
+  /// The amount that the VF should be divided by during ::execute
+  unsigned VFScaleFactor = 1;
+
 public:
+
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
   /// RdxDesc.
   VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
                        VPValue &Start, bool IsInLoop = false,
-                       bool IsOrdered = false)
+                       bool IsOrdered = false, unsigned VFScaleFactor = 1)
       : VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start),
-        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered) {
+        RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered),
+        VFScaleFactor(VFScaleFactor) {
     assert((!IsOrdered || IsInLoop) && "IsOrdered requires IsInLoop");
   }
 
@@ -1897,7 +1904,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   VPReductionPHIRecipe *clone() override {
     auto *R =
         new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
-                                 *getOperand(0), IsInLoop, IsOrdered);
+                                 *getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
     R->addOperand(getBackedgeValue());
     return R;
   }
@@ -1908,6 +1915,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
     return R->getVPDefID() == VPDef::VPReductionPHISC;
   }
 
+  void SetVFScaleFactor(unsigned ScaleFactor) {
+    VFScaleFactor = ScaleFactor;
+  }
+
   /// Generate the phi/select nodes.
   void execute(VPTransformState &State) override;
 
@@ -1928,6 +1939,32 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
   bool isInLoop() const { return IsInLoop; }
 };
 
+class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
+  unsigned Opcode;
+public:
+  template <typename IterT>
+  VPPartialReductionRecipe(Instruction &I,
+                           iterator_range<IterT> Operands) : VPRecipeWithIRFlags(
+    VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode())
+  {}
+  ~VPPartialReductionRecipe() override = default;
+  VPPartialReductionRecipe *clone() override {
+    auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
+    R->transferFlags(*this);
+    return R;
+  }
+  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
+  /// Generate the reduction in the loop
+  void execute(VPTransformState &State) override;
+  void postInsertionOp() override;
+  unsigned getOpcode() { return Opcode; }
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+};
+
 /// A recipe for vectorizing a phi-node as a sequence of mask-based select
 /// instructions.
 class VPBlendRecipe : public VPSingleDefRecipe {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 5f93339083f0c..8a75668886599 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -208,6 +208,10 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
   llvm_unreachable("Unhandled opcode");
 }
 
+Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPPartialReductionRecipe *R) {
+  return R->getUnderlyingInstr()->getType();
+}
+
 Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   if (Type *CachedTy = CachedTypes.lookup(V))
     return CachedTy;
@@ -238,7 +242,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             return inferScalarType(R->getOperand(0));
           })
           .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
-                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe>(
+                VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe, VPPartialReductionRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
             // TODO: Use info from interleave group.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7d310b1b31b6f..3bd8d24542199 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -23,6 +23,7 @@ class VPWidenIntOrFpInductionRecipe;
 class VPWidenMemoryRecipe;
 struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
+class VPPartialReductionRecipe;
 class Type;
 
 /// An analysis for type-inference for VPValues.
@@ -49,6 +50,7 @@ class VPTypeAnalysis {
   Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
   Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
   Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
+  Type *inferScalarTypeForRecipe(const VPPartialReductionRecipe *R);
 
 public:
   VPTypeAnalysis(Type *CanonicalIVTy, LLVMContext &Ctx)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 9ec422ec002c8..9aff5dd0a7771 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -245,6 +245,76 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB,
   insertBefore(BB, I);
 }
 
+void VPPartialReductionRecipe::execute(VPTransformState &State) {
+  State.setDebugLocFrom(getDebugLoc());
+  auto &Builder = State.Builder;
+
+  switch(Opcode) {
+  case Instruction::Add: {
+
+    for (unsigned Part = 0; Part < State.UF; ++Part) {
+      Value* Mul = nullptr;
+      Value* Phi = nullptr;
+      SmallVector<Value*, 2> Ops;
+      for (VPValue *VPOp : operands()) {
+        auto *Op = State.get(VPOp, Part);
+        Ops.push_back(Op);
+        if(isa<PHINode>(Op))
+          Phi = Op;
+        else
+          Mul = Op;
+      }
+
+      assert(Phi && Mul && "Phi and Mul must be set");
+      assert(isa<ScalableVectorType>(Ops[0]->getType()) && "Type must be a scalable vector");
+
+      ScalableVectorType *FullTy = cast<ScalableVectorType>(Ops[0]->getType());
+      Type *RetTy = ScalableVectorType::get(FullTy->getScalarType(), 4);
+
+      Intrinsic:...
[truncated]

@NickGuy-Arm
Copy link
Contributor Author

This patch only implements the pattern recognition and production of the partial reduction intrinsic, it does not yet lower the intrinsic to valid IR/Asm, those will be coming later.
I'm also away for the next week, so will address comments when I return

Copy link

github-actions bot commented May 16, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Collaborator

@davemgreen davemgreen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi - Sounds like a nice approach.

This patch only implements the pattern recognition and production of the partial reduction intrinsic, it does not yet lower the intrinsic to valid IR/Asm, those will be coming later.

They might need to come first, or at least be committed first. The intrinsics will need language ref which will need to be agreed upon, and some generic lowering.

@@ -512,6 +512,16 @@ class VectorType : public Type {
EltCnt.divideCoefficientBy(2));
}

/// This static method returns a VectorType with quarter as many elements as the
/// input type and the same element type.
static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be more generic than just 4x wider. I believe an ADDP would be a 2 x wider partial reduction for example. The input type needs to be a multiple of the output type, and it might be easier to keep it to a power-2 factor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed the restriction of only being 4x, opting instead for any vector type being valid and having the restrictions be defined by whatever emits the intrinsic (In this case, the Loop Vectorizer)

Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm understanding correctly, a "partial reduction" is just a slightly different way of generating code for a reduction? Basically, instead of performing the reduction using a number of lanes equal to the vector factor, you combine some of the lanes each iteration. Usually, this wouldn't really be profitable unless you have a register pressure problem. But in very specific cases, you can use specialized instructions that do horizontal sums, in which case it's extremely profitable. (This is why the testcase is called "partial-reduced-sdot.ll", I assume.)

It seems a bit weird to me to introduce a new intrinsic that, in the general case, isn't actually a natively supported operation on any target.

llvm/test/CodeGen/AArch64/partial-reduce-sdot.ll Outdated Show resolved Hide resolved
@huntergr-arm
Copy link
Collaborator

huntergr-arm commented May 17, 2024

If I'm understanding correctly, a "partial reduction" is just a slightly different way of generating code for a reduction? Basically, instead of performing the reduction using a number of lanes equal to the vector factor, you combine some of the lanes each iteration. Usually, this wouldn't really be profitable unless you have a register pressure problem. But in very specific cases, you can use specialized instructions that do horizontal sums, in which case it's extremely profitable. (This is why the testcase is called "partial-reduced-sdot.ll", I assume.)

It seems a bit weird to me to introduce a new intrinsic that, in the general case, isn't actually a natively supported operation on any target.

Hi,

Yes, it's effectively a way of representing a reduction that allows us to vectorize with a wider VF than we normally would, since the IR extends the elements loaded from memory. For the AArch64 instructions we're targeting (sdot, udot, etc.) the extension is part of the instruction; e.g. sdot of two <vscale x 16 x i8> inputs results in a <vscale x 4 x i32> output. While this may be interesting for some actual dot products in SLP vectorization, for this patch we're just interesting in increasing our VF where possible.

I posted PRs last year for a different approach which only widened the VF in LoopVec and pattern-matched to aarch64-specific dot product instructions in a target-specific pass. There was no real interest in those PRs and I was asked to consider a different approach. Nick has now implemented the suggested approach.

(obsolete LoopVec PR to widen VF: #69587)
(obsolete AArch64-specific target pass PR to pattern match the resulting IR: #69583)

@paulwalker-arm 's RFC for the alternative: https://discourse.llvm.org/t/rfc-is-a-more-expressive-way-to-represent-reductions-useful/74929

@paulwalker-arm
Copy link
Collaborator

It seems a bit weird to me to introduce a new intrinsic that, in the general case, isn't actually a natively supported operation on any target.

I see it more about giving LLVM IR a more powerful representation of reductions than we have today. The current representation effectively demands a specific order in which elements are reduced that is hard to break down (as can be seen with Graham's original patches).

By dissociating input and output types we can make VF decisions that better reflect the input data whilst at the same time express there is no defined ordering for how the inputs are reduced. For AArch64 specifically I'm hoping this goes beyond just dot instructions and allow us to make better use of paired and top-bottom instructions. I'd expect targets that have no special instructions to simply select the output type to match the input and then code generate a standard binop as they do today.

Perhaps there's an argument the new intrinsics can replace the current vector_reduce_ ones which are another special case being they have a single element result.

@NickGuy-Arm
Copy link
Contributor Author

I've separated out the recent work into logical chunks that, while conceptually could be separate PRs, are still somewhat inter-dependent and are untested in isolation. I could separate them out to different PRs if necessary, however I feel there is value in not fragmenting any discussions.

@paulwalker-arm
Copy link
Collaborator

I could separate them out to different PRs if necessary, however I feel there is value in not fragmenting any discussions.

As a minimum the intrinsic and its code generation should be broken out into its own PR. There's never a good reason for code generation and IR optimisation work to be combined because the intrinsic should be able to stand on its own merits.

@NickGuy-Arm
Copy link
Contributor Author

I've pulled the intrinsic & it's codegen out to #94499, I'll remove the relevant changes from this PR (once I figure out how to emulate PR dependencies)

Copy link
Collaborator

@SamTebbs33 SamTebbs33 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks really clean! Happy for it to land as-is.

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAICT this isn't driven by cost at all? Could this be done as VPlan-to-VPlan transform that replaces regular reduction recipes with partial ones?

@@ -784,6 +784,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
/// \returns an iterator pointing to the element after the erased one
iplist<VPRecipeBase>::iterator eraseFromParent();

virtual void postInsertionOp() {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this used for, needs doc-comment

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe this is necessary, or desired. It modifies another VPRecipe (the reduction phi) after all instructions have a defined recipe. If my idea about storing information in the cost model is used, then VPRecipeBuilder will have the necessary information at the time the initial reduction phi recipe is created.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -1915,23 +1917,27 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
/// The phi is part of an ordered reduction. Requires IsInLoop to be true.
bool IsOrdered;

/// The amount that the VF should be divided by during ::execute
unsigned VFScaleFactor = 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be explained better?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Opcode(I.getOpcode()), Scale(Scale) {}
~VPPartialReductionRecipe() override = default;
VPPartialReductionRecipe *clone() override {
auto *R =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only used for epilogue vectorization, should be unreachable if not supported/tested yet

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -1962,6 +1970,35 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
bool isInLoop() const { return IsInLoop; }
};

class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is a VPSingleDef recipe, VPSingleDefRecipe::classof needs to be updated

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -1962,6 +1970,35 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
bool isInLoop() const { return IsInLoop; }
};

class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent << "PARTIAL-REDUCE ";
printAsOperand(O, SlotTracker);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs printing test

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a printing-level test but it seems like the recipe is executed before printing happens, so we instead see the computer-reduction-result and such.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Graham's suggested changes have actually caused the recipe to be printed at the proper time.

@@ -0,0 +1,96 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -S < %s | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests for LV should be in llvm/test/Transforms/LoopVectorize

Also, does this need negative tests?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -7978,6 +7978,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {

if(!TLI.shouldExpandPartialReductionIntrinsic(&I))) {
visitTargetIntrinsic(I, Intrinsic);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs codegen test?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is always being expanded at the moment. A future PR will enable lowering of partial reductions for Arm.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, but does this change need to be part of the LV changes? And would it be possible test this change separately?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be possible to remove this change from this PR so I'll give that a go.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, removing this change from this PR would mean adding it to the AArch64 lowering PR which is not the cleanest, or creating a trivially small patch that adds the target hook, which I don't think is worth it, so I think this needs to stay in here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I've misunderstood but the TargetLowering changes only relate to code generation whereas the meat of this patch relates to a LoopVectorize transformation.

To me moving the TargetLowering change into the PR that exercise the change (currently this test is dead code with visitTargetIntrinsic being unreachable) into the PR that implements the visitTargetIntrinsic side of the branch is the way to go.

Copy link
Collaborator

@SamTebbs33 SamTebbs33 Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense to me, I've removed it now and will fold it into the AArch64 codegen PR.

@@ -293,6 +293,19 @@ struct FixedScalableVFPair {
bool hasVector() const { return FixedVF.isVector() || ScalableVF.isVector(); }
};

struct PartialReductionChain {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be moved to the CostModel class; the only reason it's outside is to pass a Chain object to VPRecipeBuilder::tryToCreatePartialReduction, and that only requires the reduction and scalefactor, so you could just pass those as direct parameters when trying to create the recipe.

The partial reduction recipe probably doesn't even need scalefactor since the input operands will already have the necessary types; it's just the reduction phi recipe that needs it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -331,6 +344,8 @@ class LoopVectorizationPlanner {
/// Profitable vector factors.
SmallVector<VectorizationFactor, 8> ProfitableVFs;

SmallVector<PartialReductionChain> PartialReductionChains;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed in the scope of LVP or can be limited to constructing the VPlans?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this is cleaner than passing it between functions with parameters. And the functions that use it are all part of LoopVectorizationPlanner anyway.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be moved to the cost model, which is available to both the LoopVectorizationPlanner and the VPRecipeBuilder. It can be handled similarly to CallInsts, with a map between the relevant instruction (loop exit instr from the recurrence descriptor?) combined with VF, and a struct containing the relevant details. This also means you don't have to pass the Planner into the cost model.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I couldn't include VF in the map since it's not available at the point of creating the recipe and it would be needed to fetch the chain from the map.

Copy link
Collaborator

@SamTebbs33 SamTebbs33 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for having a look Florian. I'll be taking over this PR while Nick is away.

@@ -1962,6 +1970,35 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
bool isInLoop() const { return IsInLoop; }
};

class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent << "PARTIAL-REDUCE ";
printAsOperand(O, SlotTracker);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a printing-level test but it seems like the recipe is executed before printing happens, so we instead see the computer-reduction-result and such.

@@ -0,0 +1,96 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -S < %s | FileCheck %s
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Opcode(I.getOpcode()), Scale(Scale) {}
~VPPartialReductionRecipe() override = default;
VPPartialReductionRecipe *clone() override {
auto *R =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -7978,6 +7978,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {

if(!TLI.shouldExpandPartialReductionIntrinsic(&I))) {
visitTargetIntrinsic(I, Intrinsic);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is always being expanded at the moment. A future PR will enable lowering of partial reductions for Arm.

@@ -331,6 +344,8 @@ class LoopVectorizationPlanner {
/// Profitable vector factors.
SmallVector<VectorizationFactor, 8> ProfitableVFs;

SmallVector<PartialReductionChain> PartialReductionChains;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this is cleaner than passing it between functions with parameters. And the functions that use it are all part of LoopVectorizationPlanner anyway.

/// Returns a struct containing the ratio between the two VFs and other cached
/// information, or null if no scalable reduction was found.
static std::optional<PartialReductionChain>
getScaledReduction(PHINode *PHI, const RecurrenceDescriptor &Rdx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be moved to VPRecipeBuilder, thus not needing TTI/CM passed?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice idea, done.

if (std::optional<PartialReductionChain> Chain =
getScaledReduction(Phi, RdxDesc, &TTI, Range, CM))
PartialReductionChains.push_back(*Chain);
RecipeBuilder.addScaledReductionExitInstrs(PartialReductionChains);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would probably still be good to keep all logic to collect partial reductions in VPRecipeBuilder and have a single VPRecipeBuilder::collectScaledReductions if possible

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -34,6 +53,9 @@ class VPRecipeBuilder {
/// Target Library Info.
const TargetLibraryInfo *TLI;

// Target Transform Info
const TargetTransformInfo *TTI;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not used in the current version, but could if getScaledReduction is moved to builder?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it is indeed needed after moving getScaledReduction.

: std::make_optional(It->second);
}

void addScaledReductionExitInstrs(SmallVector<PartialReductionChain> Chains);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void addScaledReductionExitInstrs(SmallVector<PartialReductionChain> Chains);
void addScaledReductionExitInstrs(ArrayRef<PartialReductionChain> Chains);

Avoid passing SmallVector by value, better to use ArrayRef or const SmallVector<> &

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary now that collectScaledReductions exists.

Comment on lines 166 to 167
// Create and return a partial reduction recipe for a reduction instruction
// along with binary operation and reduction phi operands.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Create and return a partial reduction recipe for a reduction instruction
// along with binary operation and reduction phi operands.
/// Create and return a partial reduction recipe for a reduction instruction
/// along with binary operation and reduction phi operands.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -0,0 +1,69 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this test needed? test/Codegen/AArch64 should only contain code-gen related tests

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it's a hold over from a previous version. Removed now.

@@ -0,0 +1,1203 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt -passes=loop-vectorize -force-vector-interleave=1 -force-target-instruction-cost=1 -S < %s | FileCheck %s --check-prefixes=CHECK,CHECK-INTERLEAVE1
; RUN: opt -passes=loop-vectorize -force-target-instruction-cost=1 -S < %s | FileCheck %s --check-prefixes=CHECK,CHECK-INTERLEAVED
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have at least some tests that run without -force-target-instruction-cost to exercise the default cost model path

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I've merged it with the max vector bandwidth line to reduce the number of extra run lines.

Copy link
Collaborator

@SamTebbs33 SamTebbs33 Nov 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't reply directly to your comment about the lack of scalable vector types, but the reason for them being missing is because we're not maximising the vector bandwidth. The new run line that enables max bandwidth shows scalable types. Removing the sve attribute does predictably prevent vectorisation altogether (likewise when replacing sve with neon).

; CHECK-INTERLEAVE1: middle.block:
; CHECK-INTERLEAVE1-NEXT: [[TMP27:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE1]])
; CHECK-INTERLEAVE1-NEXT: br i1 true, label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[VEC_EPILOG_ITER_CHECK:%.*]]
; CHECK-INTERLEAVE1: vec.epilog.iter.check:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also looks like we generated epilogue vector loops for all cases, which makes the checks bigger than necessary. Could you disable epilogue vectorization for most tests, possibly having a single one that checks the epilogue code path (probably needs a separate file to test epilogue vectorization) ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, do let me know if I should add more tests to the epilogue vectorisation file.

%exitcond.not = icmp eq i64 %indvars.iv.next, 0
br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body
}
attributes #0 = { nofree norecurse nosync nounwind memory(argmem: readwrite) uwtable vscale_range(1,16) "target-features"="+sve" }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like all tests that emit llvm.experimental.vector.partial.reduce.add only use fixed width variants. Is this expected? If so, can the +sve target feature be dropped? (If both scalable and fixed variants are expected, we should have tests with and without +sve

Also, none of nofree norecurse nosync nounwind memory(argmem: readwrite) uwtable

; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<[[ACC:%.+]]> = phi ir<0>, ir<%add>
; CHECK-NEXT: vp<[[STEPS:%.+]]> = SCALAR-STEPS vp<[[CAN_IV]]>, ir<1>
; CHECK-NEXT: CLONE ir<%gep.a> = getelementptr ir<%a>, vp<[[STEPS]]>
; CHECK-NEXT: vp<%4> = vector-pointer ir<%gep.a>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be good to use pattern here and for other vp<> values

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@SamTebbs33
Copy link
Collaborator

If the patch looks good, I'd like to land this within the next couple of days.

@fhahn
Copy link
Contributor

fhahn commented Nov 13, 2024

If the patch looks good, I'd like to land this within the next couple of days.

Should be able to take another look tomorrow or on Friday

Comment on lines +2088 to +2089
virtual InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document, this should include the definition of what partial reduction means in this context (possibly tying to the definition of the intrinsic?)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };

static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction *I) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to document and would probably be good to avoid pulling in Instructions.h here, possibly by moving the definition to TTIImpl?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be looking at the wrong version, but it looks like the code is unchanged?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I must have mis-marked this as done. Where do you actually want this moved to? I don't see any classes or structs called TTIImpl.

if (VF.isFixed() && !ST->isNeonAvailable() && !ST->hasDotProd())
return Invalid;

// FIXME: There should be a nicer way of doing this?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nicer way of doing what?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calculating the cost below, it's not great-looking code but I couldn't think of a cleaner way of doing it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like this comment could do with some elaboration, i.e. what is "this" in this case? What is the purpose of the below code, and why is it a FIXME? These sort of questions should be answered by the comment itself, rather than having to infer from surrounding cues or github comments to figure it out (or the FIXME comment removed if it's not valuable)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah probably good to be more explicit in the comment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just removed it.

@@ -8663,6 +8663,113 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
return Recipe;
}

void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
// Find all possible partial reductions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
// Find all possible partial reductions
// Find all possible partial reductions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// something that isn't another partial reduction. This is because the
// extends are intended to be lowered along with the reduction itself.

// Build up a set of partial reduction bin ops for efficient use checking
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
// Build up a set of partial reduction bin ops for efficient use checking
// Build up a set of partial reduction bin ops for efficient use checking.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Instruction *ExtendA;
Instruction *ExtendB;

Instruction *BinOp;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be good to document.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


/// The scaling factor between the size of the reduction type and the
/// (possibly extended) inputs.
unsigned ScaleFactor;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If only moved inside VPRecipeBuilder, might be good to move it inside?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved to to the partial reduction map.


VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)

/// Generate the reduction in the loop
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
/// Generate the reduction in the loop
/// Generate the reduction in the loop.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto *ExtA = cast<Instruction>(BinOp->getOperand(0));
auto *ExtB = cast<Instruction>(BinOp->getOperand(1));
Value *A = ExtA->getOperand(0);
return Ctx.TTI.getPartialReductionCost(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the cost be computed by avoiding to rely in the underlying IR instructions, e.g. types should be determined by VPTypeAnalysis, extend kinds from the extend recipes. VPlan simplifications may change the recipe's operands, meaning the underlying IR isn't accurate any more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still pending?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, let me know how it looks now.

O << Indent << "PARTIAL-REDUCE ";
printAsOperand(O, SlotTracker);
O << " = " << Instruction::getOpcodeName(Opcode);
printFlags(O);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the recipe support flags? If so, they should be set during the recipe's ::execute as well

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor Author

@NickGuy-Arm NickGuy-Arm Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the recipe support flags?

I don't think it does, no. With this PR as-is, the call to setFlags (introduced in 93fc7af) does nothing as the operation type of the recipe is OperationType::Other. This would however change with my suggestion of setting the underlying value via another constructor, so that would be something to consider (either adding proper support for the flags, or removing the setFlags and printFlags calls).

Additionally; Unless I'm missing something, CallInst doesn't support any of the flags that could be propagated via setFlags, do correct me on that if I'm wrong though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to remove setFlags and printFlags. Done that now.

if (VF.isFixed() && !ST->isNeonAvailable() && !ST->hasDotProd())
return Invalid;

// FIXME: There should be a nicer way of doing this?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like this comment could do with some elaboration, i.e. what is "this" in this case? What is the purpose of the below code, and why is it a FIXME? These sort of questions should be answered by the comment itself, rather than having to infer from surrounding cues or github comments to figure it out (or the FIXME comment removed if it's not valuable)

};

// Check if each use of a chain's two extends is a partial reduction
// and only add those those that don't have non-partial reduction users.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double "those"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

Instruction *ExtendA;
Instruction *ExtendB;

/// The binary operation using the extends that is then reduced
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// The binary operation using the extends that is then reduced
/// The binary operation using the extends that is then reduced.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
}

/// Examines reduction operations to see if the target can use a cheaper
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this comment here if we have the same comment on the declaration? Can we remove this one?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -111,20 +139,46 @@ class VPRecipeBuilder {
VPHistogramRecipe *tryToWidenHistogram(const HistogramInfo *HI,
ArrayRef<VPValue *> Operands);

/// Examines reduction operations to see if the target can use a cheaper
/// operation with a wider per-iteration input VF and narrower PHI VF.
/// Returns a struct containing the ratio between the two VFs and other cached
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still accurate with the change to return a std::pair?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Comment on lines 2392 to 2395
VPPartialReductionRecipe(unsigned ReductionOpcode,
iterator_range<IterT> Operands)
: VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands),
Opcode(ReductionOpcode) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use one of the VPRecipeWithIRFlags constructors that set UnderlyingVal?

Suggested change
VPPartialReductionRecipe(unsigned ReductionOpcode,
iterator_range<IterT> Operands)
: VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands),
Opcode(ReductionOpcode) {
VPPartialReductionRecipe(Instruction &ReductionInst, iterator_range<IterT> Operands)
: VPRecipeWithIRFlags(VPDef::VPPartialReductionSC, Operands, ReductionInst), Opcode(ReductionInst.getOpcode())

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -24,6 +24,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/TargetTransformInfo.h"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this still be included?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };

static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction *I) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be looking at the wrong version, but it looks like the code is unchanged?

/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
/// takes an accumulator and a binary operation operand that itself is fed by
/// two extends. An example of an operation that uses a partial reduction is a
/// dot product, which reduces a vector to another of 4 times larger but fewer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// dot product, which reduces a vector to another of 4 times larger but fewer
/// dot product on AArch64, which reduces a vector to another of 4 times larger but fewer

I think this is specific to AArch64?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily, a dot product is a general operation that goes 4 -> 1.

if (VF.isFixed() && !ST->isNeonAvailable() && !ST->hasDotProd())
return Invalid;

// FIXME: There should be a nicer way of doing this?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah probably good to be more explicit in the comment.


// Build up a set of partial reduction bin ops for efficient use checking.
SmallSet<User *, 4> PartialReductionBinOps;
for (auto It : PartialReductionChains) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (auto It : PartialReductionChains) {
for (const auto &[_, PartialRdx] : PartialReductionChains) {

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// Build up a set of partial reduction bin ops for efficient use checking.
SmallSet<User *, 4> PartialReductionBinOps;
for (auto It : PartialReductionChains) {
if (It.first.BinOp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just double-checking there are cases where BinOp is nullptr?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There actually aren't any cases where it's null so I've removed this check.

Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));

// Check that the extends extend from the same type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Check that the extends extend from the same type
// Check that the extends extend from the same type.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

VPValue *Phi = Operands[1];
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
std::swap(BinOp, Phi);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert(isa<VPReductionPHIRecipe>(Phi) & "...");

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is already an assertion in the recipe constructor, since that is the common place for initialising the recipes.

@@ -34,6 +54,9 @@ class VPRecipeBuilder {
/// Target Library Info.
const TargetLibraryInfo *TLI;

// Target Transform Info
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Target Transform Info
// Target Transform Info.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -63,6 +86,11 @@ class VPRecipeBuilder {
/// created.
SmallVector<VPHeaderPHIRecipe *, 4> PhisToFix;

/// The set of reduction exit instructions that will be scaled to
/// a smaller VF via partial reductions. paired with the scaling factor.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// a smaller VF via partial reductions. paired with the scaling factor.
/// a smaller VF via partial reductions, paired with the scaling factor.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto *ExtA = cast<Instruction>(BinOp->getOperand(0));
auto *ExtB = cast<Instruction>(BinOp->getOperand(1));
Value *A = ExtA->getOperand(0);
return Ctx.TTI.getPartialReductionCost(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still pending?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.