Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LVI] Support using block values when handling conditions #75311

Merged
merged 1 commit into from
Jan 2, 2024

Conversation

nikic
Copy link
Contributor

@nikic nikic commented Dec 13, 2023

Currently, LVI will only use conditions like "X < C" to constrain the value of X on the relevant edge. This patch extends it to handle conditions like "X < Y" by querying the known range of Y.

This means that getValueFromCondition() and various related APIs can now return nullopt to indicate that they have pushed to the worklist, and need to be called again later. This behavior is currently controlled by a UseBlockValue option, and only enabled for actual edge value handling. All other places deriving constraints from conditions keep using the previous logic for now.

This change was originally motivated as a fix for the regression reported in #73662 (comment). Unfortunately, it doesn't actually fix it, because we run into another issue there (LVI currently is really bad at handling values used in loops).

This change has some compile-time impact (http://llvm-compile-time-tracker.com/compare.php?from=41aa0d4690a25366a5acbd4f3cbc94ca89176dfe&to=a0d36310343898db0aea04f02682f6ff72d0b60f&stat=instructions:u), but it's fairly small, in the 0.05% range.

@llvmbot
Copy link
Member

llvmbot commented Dec 13, 2023

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

Changes

Currently, LVI will only use conditions like "X < C" to constrain the value of X on the relevant edge. This patch extends it to handle conditions like "X < Y" by querying the known range of Y.

This means that getValueFromCondition() and various related APIs can now return nullopt to indicate that they have pushed to the worklist, and need to be called again later. This behavior is currently controlled by a UseBlockValue option, and only enabled for actual edge value handling. All other places deriving constraints from conditions keep using the previous logic for now.

This change was originally motivated as a fix for the regression reported in #73662 (comment). Unfortunately, it doesn't actually fix it, because we run into another issue there (LVI currently is really bad at handling values used in loops).

This change has some compile-time impact (http://llvm-compile-time-tracker.com/compare.php?from=ed2d497291f0de330e27109ce21375b41597b4a4&amp;to=a0d36310343898db0aea04f02682f6ff72d0b60f&amp;stat=instructions%3Au), but it's fairly small, in the 0.05% range.


Full diff: https://github.com/llvm/llvm-project/pull/75311.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/LazyValueInfo.cpp (+107-48)
  • (modified) llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll (+3-5)
diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index 910f6b72afefe..88cf0dd2f36df 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -434,6 +434,28 @@ class LazyValueInfoImpl {
 
   void solve();
 
+  // For the following methods, if UseBlockValue is true, the function may
+  // push additional values to the worklist and return nullopt. If
+  // UseBlockValue is false, it will never return nullopt.
+
+  std::optional<ValueLatticeElement>
+  getValueFromSimpleICmpCondition(CmpInst::Predicate Pred, Value *RHS,
+                                  const APInt &Offset, Instruction *CxtI,
+                                  bool UseBlockValue);
+
+  std::optional<ValueLatticeElement>
+  getValueFromICmpCondition(Value *Val, ICmpInst *ICI, bool isTrueDest,
+                            bool UseBlockValue);
+
+  std::optional<ValueLatticeElement>
+  getValueFromCondition(Value *Val, Value *Cond, bool IsTrueDest,
+                        bool UseBlockValue, unsigned Depth = 0);
+
+  std::optional<ValueLatticeElement> getEdgeValueLocal(Value *Val,
+                                                       BasicBlock *BBFrom,
+                                                       BasicBlock *BBTo,
+                                                       bool UseBlockValue);
+
 public:
   /// This is the query interface to determine the lattice value for the
   /// specified Value* at the context instruction (if specified) or at the
@@ -755,14 +777,10 @@ LazyValueInfoImpl::solveBlockValuePHINode(PHINode *PN, BasicBlock *BB) {
   return Result;
 }
 
-static ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond,
-                                                 bool isTrueDest = true,
-                                                 unsigned Depth = 0);
-
 // If we can determine a constraint on the value given conditions assumed by
 // the program, intersect those constraints with BBLV
 void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
-        Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
+    Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
   BBI = BBI ? BBI : dyn_cast<Instruction>(Val);
   if (!BBI)
     return;
@@ -779,17 +797,21 @@ void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
     if (I->getParent() != BB || !isValidAssumeForContext(I, BBI))
       continue;
 
-    BBLV = intersect(BBLV, getValueFromCondition(Val, I->getArgOperand(0)));
+    BBLV = intersect(BBLV, *getValueFromCondition(Val, I->getArgOperand(0),
+                                                  /*IsTrueDest*/ true,
+                                                  /*UseBlockValue*/ false));
   }
 
   // If guards are not used in the module, don't spend time looking for them
   if (GuardDecl && !GuardDecl->use_empty() &&
       BBI->getIterator() != BB->begin()) {
-    for (Instruction &I : make_range(std::next(BBI->getIterator().getReverse()),
-                                     BB->rend())) {
+    for (Instruction &I :
+         make_range(std::next(BBI->getIterator().getReverse()), BB->rend())) {
       Value *Cond = nullptr;
       if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(Cond))))
-        BBLV = intersect(BBLV, getValueFromCondition(Val, Cond));
+        BBLV = intersect(BBLV,
+                         *getValueFromCondition(Val, Cond, /*IsTrueDest*/ true,
+                                                /*UseBlockValue*/ false));
     }
   }
 
@@ -886,10 +908,14 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) {
   // If the value is undef, a different value may be chosen in
   // the select condition.
   if (isGuaranteedNotToBeUndef(Cond, AC)) {
-    TrueVal = intersect(TrueVal,
-                        getValueFromCondition(SI->getTrueValue(), Cond, true));
-    FalseVal = intersect(
-        FalseVal, getValueFromCondition(SI->getFalseValue(), Cond, false));
+    TrueVal =
+        intersect(TrueVal, *getValueFromCondition(SI->getTrueValue(), Cond,
+                                                  /*IsTrueDest*/ true,
+                                                  /*UseBlockValue*/ false));
+    FalseVal =
+        intersect(FalseVal, *getValueFromCondition(SI->getFalseValue(), Cond,
+                                                   /*IsTrueDest*/ false,
+                                                   /*UseBlockValue*/ false));
   }
 
   ValueLatticeElement Result = TrueVal;
@@ -1068,15 +1094,26 @@ static bool matchICmpOperand(APInt &Offset, Value *LHS, Value *Val,
 }
 
 /// Get value range for a "(Val + Offset) Pred RHS" condition.
-static ValueLatticeElement getValueFromSimpleICmpCondition(
-    CmpInst::Predicate Pred, Value *RHS, const APInt &Offset) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getValueFromSimpleICmpCondition(CmpInst::Predicate Pred,
+                                                   Value *RHS,
+                                                   const APInt &Offset,
+                                                   Instruction *CxtI,
+                                                   bool UseBlockValue) {
   ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(),
                          /*isFullSet=*/true);
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS))
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
     RHSRange = ConstantRange(CI->getValue());
-  else if (Instruction *I = dyn_cast<Instruction>(RHS))
+  } else if (UseBlockValue) {
+    std::optional<ValueLatticeElement> R =
+        getBlockValue(RHS, CxtI->getParent(), CxtI);
+    if (!R)
+      return std::nullopt;
+    RHSRange = toConstantRange(*R, RHS->getType());
+  } else if (Instruction *I = dyn_cast<Instruction>(RHS)) {
     if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
       RHSRange = getConstantRangeFromMetadata(*Ranges);
+  }
 
   ConstantRange TrueValues =
       ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
@@ -1103,8 +1140,8 @@ getRangeViaSLT(CmpInst::Predicate Pred, APInt RHS,
   return std::nullopt;
 }
 
-static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
-                                                     bool isTrueDest) {
+std::optional<ValueLatticeElement> LazyValueInfoImpl::getValueFromICmpCondition(
+    Value *Val, ICmpInst *ICI, bool isTrueDest, bool UseBlockValue) {
   Value *LHS = ICI->getOperand(0);
   Value *RHS = ICI->getOperand(1);
 
@@ -1128,11 +1165,13 @@ static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
   unsigned BitWidth = Ty->getScalarSizeInBits();
   APInt Offset(BitWidth, 0);
   if (matchICmpOperand(Offset, LHS, Val, EdgePred))
-    return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset);
+    return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset, ICI,
+                                           UseBlockValue);
 
   CmpInst::Predicate SwappedPred = CmpInst::getSwappedPredicate(EdgePred);
   if (matchICmpOperand(Offset, RHS, Val, SwappedPred))
-    return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset);
+    return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset, ICI,
+                                           UseBlockValue);
 
   const APInt *Mask, *C;
   if (match(LHS, m_And(m_Specific(Val), m_APInt(Mask))) &&
@@ -1212,10 +1251,12 @@ static ValueLatticeElement getValueFromOverflowCondition(
   return ValueLatticeElement::getRange(NWR);
 }
 
-static ValueLatticeElement getValueFromCondition(
-    Value *Val, Value *Cond, bool IsTrueDest, unsigned Depth) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getValueFromCondition(Value *Val, Value *Cond,
+                                         bool IsTrueDest, bool UseBlockValue,
+                                         unsigned Depth) {
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cond))
-    return getValueFromICmpCondition(Val, ICI, IsTrueDest);
+    return getValueFromICmpCondition(Val, ICI, IsTrueDest, UseBlockValue);
 
   if (auto *EVI = dyn_cast<ExtractValueInst>(Cond))
     if (auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand()))
@@ -1227,7 +1268,7 @@ static ValueLatticeElement getValueFromCondition(
 
   Value *N;
   if (match(Cond, m_Not(m_Value(N))))
-    return getValueFromCondition(Val, N, !IsTrueDest, Depth);
+    return getValueFromCondition(Val, N, !IsTrueDest, UseBlockValue, Depth);
 
   Value *L, *R;
   bool IsAnd;
@@ -1238,19 +1279,23 @@ static ValueLatticeElement getValueFromCondition(
   else
     return ValueLatticeElement::getOverdefined();
 
-  ValueLatticeElement LV = getValueFromCondition(Val, L, IsTrueDest, Depth);
-  ValueLatticeElement RV = getValueFromCondition(Val, R, IsTrueDest, Depth);
+  std::optional<ValueLatticeElement> LV =
+      getValueFromCondition(Val, L, IsTrueDest, UseBlockValue, Depth);
+  std::optional<ValueLatticeElement> RV =
+      getValueFromCondition(Val, R, IsTrueDest, UseBlockValue, Depth);
+  if (!LV || !RV)
+    return std::nullopt;
 
   // if (L && R) -> intersect L and R
   // if (!(L || R)) -> intersect !L and !R
   // if (L || R) -> union L and R
   // if (!(L && R)) -> union !L and !R
   if (IsTrueDest ^ IsAnd) {
-    LV.mergeIn(RV);
-    return LV;
+    LV->mergeIn(*RV);
+    return *LV;
   }
 
-  return intersect(LV, RV);
+  return intersect(*LV, *RV);
 }
 
 // Return true if Usr has Op as an operand, otherwise false.
@@ -1302,8 +1347,9 @@ static ValueLatticeElement constantFoldUser(User *Usr, Value *Op,
 }
 
 /// Compute the value of Val on the edge BBFrom -> BBTo.
-static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
-                                             BasicBlock *BBTo) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
+                                     BasicBlock *BBTo, bool UseBlockValue) {
   // TODO: Handle more complex conditionals. If (v == 0 || v2 < 1) is false, we
   // know that v != 0.
   if (BranchInst *BI = dyn_cast<BranchInst>(BBFrom->getTerminator())) {
@@ -1324,13 +1370,16 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
 
       // If the condition of the branch is an equality comparison, we may be
       // able to infer the value.
-      ValueLatticeElement Result = getValueFromCondition(Val, Condition,
-                                                         isTrueDest);
-      if (!Result.isOverdefined())
+      std::optional<ValueLatticeElement> Result =
+          getValueFromCondition(Val, Condition, isTrueDest, UseBlockValue);
+      if (!Result)
+        return std::nullopt;
+
+      if (!Result->isOverdefined())
         return Result;
 
       if (User *Usr = dyn_cast<User>(Val)) {
-        assert(Result.isOverdefined() && "Result isn't overdefined");
+        assert(Result->isOverdefined() && "Result isn't overdefined");
         // Check with isOperationFoldable() first to avoid linearly iterating
         // over the operands unnecessarily which can be expensive for
         // instructions with many operands.
@@ -1356,8 +1405,8 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
             //    br i1 %Condition, label %then, label %else
             for (unsigned i = 0; i < Usr->getNumOperands(); ++i) {
               Value *Op = Usr->getOperand(i);
-              ValueLatticeElement OpLatticeVal =
-                  getValueFromCondition(Op, Condition, isTrueDest);
+              ValueLatticeElement OpLatticeVal = *getValueFromCondition(
+                  Op, Condition, isTrueDest, /*UseBlockValue*/ false);
               if (std::optional<APInt> OpConst =
                       OpLatticeVal.asConstantInteger()) {
                 Result = constantFoldUser(Usr, Op, *OpConst, DL);
@@ -1367,7 +1416,7 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
           }
         }
       }
-      if (!Result.isOverdefined())
+      if (!Result->isOverdefined())
         return Result;
     }
   }
@@ -1432,8 +1481,12 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
   if (Constant *VC = dyn_cast<Constant>(Val))
     return ValueLatticeElement::get(VC);
 
-  ValueLatticeElement LocalResult = getEdgeValueLocal(Val, BBFrom, BBTo);
-  if (hasSingleValue(LocalResult))
+  std::optional<ValueLatticeElement> LocalResult =
+      getEdgeValueLocal(Val, BBFrom, BBTo, /*UseBlockValue*/ true);
+  if (!LocalResult)
+    return std::nullopt;
+
+  if (hasSingleValue(*LocalResult))
     // Can't get any more precise here
     return LocalResult;
 
@@ -1453,7 +1506,7 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
   // but then the result is not cached.
   intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, CxtI);
 
-  return intersect(LocalResult, InBlock);
+  return intersect(*LocalResult, InBlock);
 }
 
 ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB,
@@ -1499,10 +1552,12 @@ getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB,
 
   std::optional<ValueLatticeElement> Result =
       getEdgeValue(V, FromBB, ToBB, CxtI);
-  if (!Result) {
+  while (!Result) {
+    // As the worklist only explicitly tracks block values (but not edge values)
+    // we may have to call solve() multiple times, as the edge value calculation
+    // may request additional block values.
     solve();
     Result = getEdgeValue(V, FromBB, ToBB, CxtI);
-    assert(Result && "More work to do after problem solved?");
   }
 
   LLVM_DEBUG(dbgs() << "  Result = " << *Result << "\n");
@@ -1528,13 +1583,17 @@ ValueLatticeElement LazyValueInfoImpl::getValueAtUse(const Use &U) {
       if (!isGuaranteedNotToBeUndef(SI->getCondition(), AC))
         break;
       if (CurrU->getOperandNo() == 1)
-        CondVal = getValueFromCondition(V, SI->getCondition(), true);
+        CondVal =
+            *getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ true,
+                                   /*UseBlockValue*/ false);
       else if (CurrU->getOperandNo() == 2)
-        CondVal = getValueFromCondition(V, SI->getCondition(), false);
+        CondVal =
+            *getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ false,
+                                   /*UseBlockValue*/ false);
     } else if (auto *PHI = dyn_cast<PHINode>(CurrI)) {
       // TODO: Use non-local query?
-      CondVal =
-          getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU), PHI->getParent());
+      CondVal = *getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU),
+                                   PHI->getParent(), /*UseBlockValue*/ false);
     }
     if (CondVal)
       VL = intersect(VL, *CondVal);
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll b/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
index d30b31d317a6d..252f6596cedc5 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
@@ -12,8 +12,7 @@ define void @test_icmp_from_implied_cond(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L2:%.*]], label [[END]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[B_CMP1:%.*]] = icmp ult i32 [[B]], 32
-; CHECK-NEXT:    call void @use(i1 [[B_CMP1]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[B_CMP2:%.*]] = icmp ult i32 [[B]], 31
 ; CHECK-NEXT:    call void @use(i1 [[B_CMP2]])
 ; CHECK-NEXT:    ret void
@@ -47,7 +46,7 @@ define i64 @test_sext_from_implied_cond(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L2:%.*]], label [[END]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[SEXT:%.*]] = sext i32 [[B]] to i64
+; CHECK-NEXT:    [[SEXT:%.*]] = zext nneg i32 [[B]] to i64
 ; CHECK-NEXT:    ret i64 [[SEXT]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i64 0
@@ -74,8 +73,7 @@ define void @test_icmp_from_implied_range(i16 %x, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L1:%.*]], label [[END:%.*]]
 ; CHECK:       l1:
-; CHECK-NEXT:    [[B_CMP1:%.*]] = icmp ult i32 [[B]], 65535
-; CHECK-NEXT:    call void @use(i1 [[B_CMP1]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[B_CMP2:%.*]] = icmp ult i32 [[B]], 65534
 ; CHECK-NEXT:    call void @use(i1 [[B_CMP2]])
 ; CHECK-NEXT:    ret void

@llvmbot
Copy link
Member

llvmbot commented Dec 13, 2023

@llvm/pr-subscribers-llvm-analysis

Author: Nikita Popov (nikic)

Changes

Currently, LVI will only use conditions like "X < C" to constrain the value of X on the relevant edge. This patch extends it to handle conditions like "X < Y" by querying the known range of Y.

This means that getValueFromCondition() and various related APIs can now return nullopt to indicate that they have pushed to the worklist, and need to be called again later. This behavior is currently controlled by a UseBlockValue option, and only enabled for actual edge value handling. All other places deriving constraints from conditions keep using the previous logic for now.

This change was originally motivated as a fix for the regression reported in #73662 (comment). Unfortunately, it doesn't actually fix it, because we run into another issue there (LVI currently is really bad at handling values used in loops).

This change has some compile-time impact (http://llvm-compile-time-tracker.com/compare.php?from=ed2d497291f0de330e27109ce21375b41597b4a4&amp;to=a0d36310343898db0aea04f02682f6ff72d0b60f&amp;stat=instructions%3Au), but it's fairly small, in the 0.05% range.


Full diff: https://github.com/llvm/llvm-project/pull/75311.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/LazyValueInfo.cpp (+107-48)
  • (modified) llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll (+3-5)
diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index 910f6b72afefe2..88cf0dd2f36df0 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -434,6 +434,28 @@ class LazyValueInfoImpl {
 
   void solve();
 
+  // For the following methods, if UseBlockValue is true, the function may
+  // push additional values to the worklist and return nullopt. If
+  // UseBlockValue is false, it will never return nullopt.
+
+  std::optional<ValueLatticeElement>
+  getValueFromSimpleICmpCondition(CmpInst::Predicate Pred, Value *RHS,
+                                  const APInt &Offset, Instruction *CxtI,
+                                  bool UseBlockValue);
+
+  std::optional<ValueLatticeElement>
+  getValueFromICmpCondition(Value *Val, ICmpInst *ICI, bool isTrueDest,
+                            bool UseBlockValue);
+
+  std::optional<ValueLatticeElement>
+  getValueFromCondition(Value *Val, Value *Cond, bool IsTrueDest,
+                        bool UseBlockValue, unsigned Depth = 0);
+
+  std::optional<ValueLatticeElement> getEdgeValueLocal(Value *Val,
+                                                       BasicBlock *BBFrom,
+                                                       BasicBlock *BBTo,
+                                                       bool UseBlockValue);
+
 public:
   /// This is the query interface to determine the lattice value for the
   /// specified Value* at the context instruction (if specified) or at the
@@ -755,14 +777,10 @@ LazyValueInfoImpl::solveBlockValuePHINode(PHINode *PN, BasicBlock *BB) {
   return Result;
 }
 
-static ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond,
-                                                 bool isTrueDest = true,
-                                                 unsigned Depth = 0);
-
 // If we can determine a constraint on the value given conditions assumed by
 // the program, intersect those constraints with BBLV
 void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
-        Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
+    Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
   BBI = BBI ? BBI : dyn_cast<Instruction>(Val);
   if (!BBI)
     return;
@@ -779,17 +797,21 @@ void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
     if (I->getParent() != BB || !isValidAssumeForContext(I, BBI))
       continue;
 
-    BBLV = intersect(BBLV, getValueFromCondition(Val, I->getArgOperand(0)));
+    BBLV = intersect(BBLV, *getValueFromCondition(Val, I->getArgOperand(0),
+                                                  /*IsTrueDest*/ true,
+                                                  /*UseBlockValue*/ false));
   }
 
   // If guards are not used in the module, don't spend time looking for them
   if (GuardDecl && !GuardDecl->use_empty() &&
       BBI->getIterator() != BB->begin()) {
-    for (Instruction &I : make_range(std::next(BBI->getIterator().getReverse()),
-                                     BB->rend())) {
+    for (Instruction &I :
+         make_range(std::next(BBI->getIterator().getReverse()), BB->rend())) {
       Value *Cond = nullptr;
       if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(Cond))))
-        BBLV = intersect(BBLV, getValueFromCondition(Val, Cond));
+        BBLV = intersect(BBLV,
+                         *getValueFromCondition(Val, Cond, /*IsTrueDest*/ true,
+                                                /*UseBlockValue*/ false));
     }
   }
 
@@ -886,10 +908,14 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) {
   // If the value is undef, a different value may be chosen in
   // the select condition.
   if (isGuaranteedNotToBeUndef(Cond, AC)) {
-    TrueVal = intersect(TrueVal,
-                        getValueFromCondition(SI->getTrueValue(), Cond, true));
-    FalseVal = intersect(
-        FalseVal, getValueFromCondition(SI->getFalseValue(), Cond, false));
+    TrueVal =
+        intersect(TrueVal, *getValueFromCondition(SI->getTrueValue(), Cond,
+                                                  /*IsTrueDest*/ true,
+                                                  /*UseBlockValue*/ false));
+    FalseVal =
+        intersect(FalseVal, *getValueFromCondition(SI->getFalseValue(), Cond,
+                                                   /*IsTrueDest*/ false,
+                                                   /*UseBlockValue*/ false));
   }
 
   ValueLatticeElement Result = TrueVal;
@@ -1068,15 +1094,26 @@ static bool matchICmpOperand(APInt &Offset, Value *LHS, Value *Val,
 }
 
 /// Get value range for a "(Val + Offset) Pred RHS" condition.
-static ValueLatticeElement getValueFromSimpleICmpCondition(
-    CmpInst::Predicate Pred, Value *RHS, const APInt &Offset) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getValueFromSimpleICmpCondition(CmpInst::Predicate Pred,
+                                                   Value *RHS,
+                                                   const APInt &Offset,
+                                                   Instruction *CxtI,
+                                                   bool UseBlockValue) {
   ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(),
                          /*isFullSet=*/true);
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS))
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
     RHSRange = ConstantRange(CI->getValue());
-  else if (Instruction *I = dyn_cast<Instruction>(RHS))
+  } else if (UseBlockValue) {
+    std::optional<ValueLatticeElement> R =
+        getBlockValue(RHS, CxtI->getParent(), CxtI);
+    if (!R)
+      return std::nullopt;
+    RHSRange = toConstantRange(*R, RHS->getType());
+  } else if (Instruction *I = dyn_cast<Instruction>(RHS)) {
     if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
       RHSRange = getConstantRangeFromMetadata(*Ranges);
+  }
 
   ConstantRange TrueValues =
       ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
@@ -1103,8 +1140,8 @@ getRangeViaSLT(CmpInst::Predicate Pred, APInt RHS,
   return std::nullopt;
 }
 
-static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
-                                                     bool isTrueDest) {
+std::optional<ValueLatticeElement> LazyValueInfoImpl::getValueFromICmpCondition(
+    Value *Val, ICmpInst *ICI, bool isTrueDest, bool UseBlockValue) {
   Value *LHS = ICI->getOperand(0);
   Value *RHS = ICI->getOperand(1);
 
@@ -1128,11 +1165,13 @@ static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
   unsigned BitWidth = Ty->getScalarSizeInBits();
   APInt Offset(BitWidth, 0);
   if (matchICmpOperand(Offset, LHS, Val, EdgePred))
-    return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset);
+    return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset, ICI,
+                                           UseBlockValue);
 
   CmpInst::Predicate SwappedPred = CmpInst::getSwappedPredicate(EdgePred);
   if (matchICmpOperand(Offset, RHS, Val, SwappedPred))
-    return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset);
+    return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset, ICI,
+                                           UseBlockValue);
 
   const APInt *Mask, *C;
   if (match(LHS, m_And(m_Specific(Val), m_APInt(Mask))) &&
@@ -1212,10 +1251,12 @@ static ValueLatticeElement getValueFromOverflowCondition(
   return ValueLatticeElement::getRange(NWR);
 }
 
-static ValueLatticeElement getValueFromCondition(
-    Value *Val, Value *Cond, bool IsTrueDest, unsigned Depth) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getValueFromCondition(Value *Val, Value *Cond,
+                                         bool IsTrueDest, bool UseBlockValue,
+                                         unsigned Depth) {
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cond))
-    return getValueFromICmpCondition(Val, ICI, IsTrueDest);
+    return getValueFromICmpCondition(Val, ICI, IsTrueDest, UseBlockValue);
 
   if (auto *EVI = dyn_cast<ExtractValueInst>(Cond))
     if (auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand()))
@@ -1227,7 +1268,7 @@ static ValueLatticeElement getValueFromCondition(
 
   Value *N;
   if (match(Cond, m_Not(m_Value(N))))
-    return getValueFromCondition(Val, N, !IsTrueDest, Depth);
+    return getValueFromCondition(Val, N, !IsTrueDest, UseBlockValue, Depth);
 
   Value *L, *R;
   bool IsAnd;
@@ -1238,19 +1279,23 @@ static ValueLatticeElement getValueFromCondition(
   else
     return ValueLatticeElement::getOverdefined();
 
-  ValueLatticeElement LV = getValueFromCondition(Val, L, IsTrueDest, Depth);
-  ValueLatticeElement RV = getValueFromCondition(Val, R, IsTrueDest, Depth);
+  std::optional<ValueLatticeElement> LV =
+      getValueFromCondition(Val, L, IsTrueDest, UseBlockValue, Depth);
+  std::optional<ValueLatticeElement> RV =
+      getValueFromCondition(Val, R, IsTrueDest, UseBlockValue, Depth);
+  if (!LV || !RV)
+    return std::nullopt;
 
   // if (L && R) -> intersect L and R
   // if (!(L || R)) -> intersect !L and !R
   // if (L || R) -> union L and R
   // if (!(L && R)) -> union !L and !R
   if (IsTrueDest ^ IsAnd) {
-    LV.mergeIn(RV);
-    return LV;
+    LV->mergeIn(*RV);
+    return *LV;
   }
 
-  return intersect(LV, RV);
+  return intersect(*LV, *RV);
 }
 
 // Return true if Usr has Op as an operand, otherwise false.
@@ -1302,8 +1347,9 @@ static ValueLatticeElement constantFoldUser(User *Usr, Value *Op,
 }
 
 /// Compute the value of Val on the edge BBFrom -> BBTo.
-static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
-                                             BasicBlock *BBTo) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
+                                     BasicBlock *BBTo, bool UseBlockValue) {
   // TODO: Handle more complex conditionals. If (v == 0 || v2 < 1) is false, we
   // know that v != 0.
   if (BranchInst *BI = dyn_cast<BranchInst>(BBFrom->getTerminator())) {
@@ -1324,13 +1370,16 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
 
       // If the condition of the branch is an equality comparison, we may be
       // able to infer the value.
-      ValueLatticeElement Result = getValueFromCondition(Val, Condition,
-                                                         isTrueDest);
-      if (!Result.isOverdefined())
+      std::optional<ValueLatticeElement> Result =
+          getValueFromCondition(Val, Condition, isTrueDest, UseBlockValue);
+      if (!Result)
+        return std::nullopt;
+
+      if (!Result->isOverdefined())
         return Result;
 
       if (User *Usr = dyn_cast<User>(Val)) {
-        assert(Result.isOverdefined() && "Result isn't overdefined");
+        assert(Result->isOverdefined() && "Result isn't overdefined");
         // Check with isOperationFoldable() first to avoid linearly iterating
         // over the operands unnecessarily which can be expensive for
         // instructions with many operands.
@@ -1356,8 +1405,8 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
             //    br i1 %Condition, label %then, label %else
             for (unsigned i = 0; i < Usr->getNumOperands(); ++i) {
               Value *Op = Usr->getOperand(i);
-              ValueLatticeElement OpLatticeVal =
-                  getValueFromCondition(Op, Condition, isTrueDest);
+              ValueLatticeElement OpLatticeVal = *getValueFromCondition(
+                  Op, Condition, isTrueDest, /*UseBlockValue*/ false);
               if (std::optional<APInt> OpConst =
                       OpLatticeVal.asConstantInteger()) {
                 Result = constantFoldUser(Usr, Op, *OpConst, DL);
@@ -1367,7 +1416,7 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
           }
         }
       }
-      if (!Result.isOverdefined())
+      if (!Result->isOverdefined())
         return Result;
     }
   }
@@ -1432,8 +1481,12 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
   if (Constant *VC = dyn_cast<Constant>(Val))
     return ValueLatticeElement::get(VC);
 
-  ValueLatticeElement LocalResult = getEdgeValueLocal(Val, BBFrom, BBTo);
-  if (hasSingleValue(LocalResult))
+  std::optional<ValueLatticeElement> LocalResult =
+      getEdgeValueLocal(Val, BBFrom, BBTo, /*UseBlockValue*/ true);
+  if (!LocalResult)
+    return std::nullopt;
+
+  if (hasSingleValue(*LocalResult))
     // Can't get any more precise here
     return LocalResult;
 
@@ -1453,7 +1506,7 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
   // but then the result is not cached.
   intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, CxtI);
 
-  return intersect(LocalResult, InBlock);
+  return intersect(*LocalResult, InBlock);
 }
 
 ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB,
@@ -1499,10 +1552,12 @@ getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB,
 
   std::optional<ValueLatticeElement> Result =
       getEdgeValue(V, FromBB, ToBB, CxtI);
-  if (!Result) {
+  while (!Result) {
+    // As the worklist only explicitly tracks block values (but not edge values)
+    // we may have to call solve() multiple times, as the edge value calculation
+    // may request additional block values.
     solve();
     Result = getEdgeValue(V, FromBB, ToBB, CxtI);
-    assert(Result && "More work to do after problem solved?");
   }
 
   LLVM_DEBUG(dbgs() << "  Result = " << *Result << "\n");
@@ -1528,13 +1583,17 @@ ValueLatticeElement LazyValueInfoImpl::getValueAtUse(const Use &U) {
       if (!isGuaranteedNotToBeUndef(SI->getCondition(), AC))
         break;
       if (CurrU->getOperandNo() == 1)
-        CondVal = getValueFromCondition(V, SI->getCondition(), true);
+        CondVal =
+            *getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ true,
+                                   /*UseBlockValue*/ false);
       else if (CurrU->getOperandNo() == 2)
-        CondVal = getValueFromCondition(V, SI->getCondition(), false);
+        CondVal =
+            *getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ false,
+                                   /*UseBlockValue*/ false);
     } else if (auto *PHI = dyn_cast<PHINode>(CurrI)) {
       // TODO: Use non-local query?
-      CondVal =
-          getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU), PHI->getParent());
+      CondVal = *getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU),
+                                   PHI->getParent(), /*UseBlockValue*/ false);
     }
     if (CondVal)
       VL = intersect(VL, *CondVal);
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll b/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
index d30b31d317a6de..252f6596cedc5e 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
@@ -12,8 +12,7 @@ define void @test_icmp_from_implied_cond(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L2:%.*]], label [[END]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[B_CMP1:%.*]] = icmp ult i32 [[B]], 32
-; CHECK-NEXT:    call void @use(i1 [[B_CMP1]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[B_CMP2:%.*]] = icmp ult i32 [[B]], 31
 ; CHECK-NEXT:    call void @use(i1 [[B_CMP2]])
 ; CHECK-NEXT:    ret void
@@ -47,7 +46,7 @@ define i64 @test_sext_from_implied_cond(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L2:%.*]], label [[END]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[SEXT:%.*]] = sext i32 [[B]] to i64
+; CHECK-NEXT:    [[SEXT:%.*]] = zext nneg i32 [[B]] to i64
 ; CHECK-NEXT:    ret i64 [[SEXT]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i64 0
@@ -74,8 +73,7 @@ define void @test_icmp_from_implied_range(i16 %x, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L1:%.*]], label [[END:%.*]]
 ; CHECK:       l1:
-; CHECK-NEXT:    [[B_CMP1:%.*]] = icmp ult i32 [[B]], 65535
-; CHECK-NEXT:    call void @use(i1 [[B_CMP1]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[B_CMP2:%.*]] = icmp ult i32 [[B]], 65534
 ; CHECK-NEXT:    call void @use(i1 [[B_CMP2]])
 ; CHECK-NEXT:    ret void

RHSRange = ConstantRange(CI->getValue());
else if (Instruction *I = dyn_cast<Instruction>(RHS))
} else if (UseBlockValue) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the core change, the rest is just threading the new nullopt return value through everything.

@llvmbot llvmbot added the clang Clang issues not falling into any other category label Dec 13, 2023
Currently, LVI will only use conditions like "X < C" to constrain
the value of X on the relevant edge. This patch extends it to
handle conditions like "X < Y" by querying the known range of Y.

This means that getValueFromCondition() and various related APIs
can now return nullopt to indicate that they have pushed to the
worklist, and need to be called again later. This behavior is
currently controlled by a UseBlockValue option, and only enabled
for actual edge value handling. All other places deriving
constraints from conditions keep using the previous logic for
now.
@nikic nikic force-pushed the lvi-val-from-cond branch from a4127e5 to fada145 Compare December 20, 2023 09:29
@nikic
Copy link
Contributor Author

nikic commented Dec 20, 2023

ping

@dtcxzyw
Copy link
Member

dtcxzyw commented Dec 29, 2023

@nikic Could you please have a look at dtcxzyw/llvm-opt-benchmark#12 (comment)?

BTW, is it expected that this patch converts a bunch of signed icmps into unsigned icmps?
See also the comment in InstCombine:

{
// Don't use dominating conditions when folding icmp using known bits. This
// may convert signed into unsigned predicates in ways that other passes
// (especially IndVarSimplify) may not be able to reliably undo.
SQ.DC = nullptr;
auto _ = make_scope_exit([&]() { SQ.DC = &DC; });
if (SimplifyDemandedBits(&I, 0, getDemandedBitsLHSMask(I, BitWidth),
Op0Known, 0))
return &I;
if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0))
return &I;
}

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation LGTM.

ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(),
/*isFullSet=*/true);
if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS))
if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop these braces.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements the preference is to use braces on all branches if one of them uses braces.

@nikic nikic merged commit 734ee0e into llvm:main Jan 2, 2024
4 checks passed
@nikic
Copy link
Contributor Author

nikic commented Jan 2, 2024

@dtcxzyw I believe this regression is due to an evaluation order issue in LVI. I've filed #76705.

dtcxzyw pushed a commit that referenced this pull request Apr 4, 2024
This adds handling of range attribute for return values of Call and
Invoke in getFromRangeMetadata and handling of argument with range
attribute in solveBlockValueNonLocal.
There is one additional check of the range metadata at line 1120 in
getValueFromSimpleICmpCondition that is not covered in this PR as after
#75311 there is no test that
cover that check any more and I have not been able to create a test that
trigger that code.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang Clang issues not falling into any other category llvm:analysis llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants