Skip to content

Commit

Permalink
Combine (shl (and x, 2^n-1), n) to (shl x, n)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinay-anubola committed Dec 4, 2024
1 parent d173e6c commit 84fedf7
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 3 deletions.
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,10 @@ class CombinerHelper {
void applyCombineTruncOfExt(MachineInstr &MI,
std::pair<Register, unsigned> &MatchInfo);

/// Transform (shl (and x, 2^n-1), n) to (shl x, n).
///
/// The match step expects \p MI to be a G_SHL with a constant shift amount
/// whose shifted operand is defined by a G_AND with a constant mask. On
/// success it stores the G_AND's first operand (x) in \p Reg.
/// NOTE(review): the AND is only redundant when the mask keeps every bit that
/// survives the shift, i.e. the low (BitWidth - n) bits; confirm the matcher
/// enforces this when n is smaller than half the bit width.
bool matchCombineShlOfAnd(MachineInstr &MI, Register &Reg);
/// Rewrites the shifted operand (operand 1) of the G_SHL \p MI to \p Reg.
void applyCombineShlOfAnd(MachineInstr &MI, Register &Reg);

/// Transform trunc (shl x, K) to shl (trunc x), K
/// if K < VT.getScalarSizeInBits().
///
Expand Down
10 changes: 9 additions & 1 deletion llvm/include/llvm/Target/GlobalISel/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,14 @@ def trunc_ext_fold: GICombineRule <
(apply [{ Helper.applyCombineTruncOfExt(*${root}, ${matchinfo}); }])
>;

// Fold (shl (and x, 2^n-1), n) -> (shl x, n)
// The rule roots on any G_SHL; the C++ matcher inspects the shifted operand
// and, on success, hands the unmasked register x to the apply step through
// $matchinfo.
def shl_and_fold: GICombineRule <
(defs root:$root, register_matchinfo:$matchinfo),
(match (wip_match_opcode G_SHL):$root,
[{ return Helper.matchCombineShlOfAnd(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyCombineShlOfAnd(*${root}, ${matchinfo}); }])
>;

// Under certain conditions, transform:
// trunc (shl x, K) -> shl (trunc x), K
// trunc ([al]shr x, K) -> (trunc ([al]shr (trunc x), K))
Expand Down Expand Up @@ -1598,7 +1606,7 @@ def const_combines : GICombineGroup<[constant_fold_fp_ops, const_ptradd_to_i2p,
// Simplification rules bundled into a single combine group. shl_and_fold is
// registered here alongside the other members of the list.
def known_bits_simplifications : GICombineGroup<[
redundant_and, redundant_sext_inreg, redundant_or, urem_pow2_to_mask,
zext_trunc_fold, icmp_to_true_false_known_bits, icmp_to_lhs_known_bits,
sext_inreg_to_zext_inreg, shl_and_fold]>;

// Combine group bundling reduce_shl_of_extend and narrow_binop_feeding_and.
def width_reduction_combines : GICombineGroup<[reduce_shl_of_extend,
narrow_binop_feeding_and]>;
Expand Down
28 changes: 28 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2553,6 +2553,34 @@ void CombinerHelper::applyCombineTruncOfExt(
MI.eraseFromParent();
}

/// Match a G_SHL by a constant amount whose shifted operand is a G_AND with a
/// constant mask that the shift makes redundant.
///
/// A left shift by n discards the top n bits of its input, so only the low
/// (BitWidth - n) bits of the input can reach the result. The G_AND may be
/// bypassed only when its mask keeps every one of those bits. A mask of
/// exactly n low ones (2^n-1) satisfies this only when 2n >= BitWidth; for
/// example, on 32 bits (x & 0xFF) << 8 is NOT equal to x << 8.
///
/// \param MI  the G_SHL under consideration.
/// \param Reg [out] on success, the first operand of the G_AND (the unmasked
///            value x). Left untouched on failure.
/// \returns true if \p MI can use \p Reg directly as its shifted operand.
bool CombinerHelper::matchCombineShlOfAnd(MachineInstr &MI, Register &Reg) {
  assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected a G_SHL");
  Register SrcReg = MI.getOperand(1).getReg();
  Register ShftReg = MI.getOperand(2).getReg();

  const auto ShftVal = getIConstantVRegSExtVal(ShftReg, MRI);
  if (!ShftVal)
    return false;

  // Only in-range shift amounts have a well-defined result. Rejecting the
  // others also keeps the signed 64-bit amount from being silently narrowed
  // when it is compared against bit counts below.
  const int64_t BitWidth = MRI.getType(SrcReg).getScalarSizeInBits();
  if (*ShftVal < 0 || *ShftVal >= BitWidth)
    return false;

  MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
  if (!SrcMI || SrcMI->getOpcode() != TargetOpcode::G_AND)
    return false;

  // Find the mask on the RHS.
  Register AndRHS = SrcMI->getOperand(2).getReg();
  const auto Cst = getIConstantVRegValWithLookThrough(AndRHS, MRI);
  if (!Cst)
    return false;

  // Every input bit that survives the shift must pass through the mask
  // unchanged; bits above that are shifted out, so the mask may do anything
  // with them.
  const unsigned SurvivingBits = static_cast<unsigned>(BitWidth - *ShftVal);
  if (Cst->Value.countTrailingOnes() < SurvivingBits)
    return false;

  Reg = SrcMI->getOperand(1).getReg();
  return true;
}

/// Apply step of the shl-of-and fold: make the G_SHL \p MI shift \p Reg
/// instead of its current first source operand, notifying the observer around
/// the in-place modification.
void CombinerHelper::applyCombineShlOfAnd(MachineInstr &MI, Register &Reg) {
  assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected a G_SHL");

  // Operand 1 is the value being shifted.
  auto &ShiftedValue = MI.getOperand(1);

  Observer.changingInstr(MI);
  ShiftedValue.setReg(Reg);
  Observer.changedInstr(MI);
}

static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) {
const unsigned ShiftSize = ShiftTy.getScalarSizeInBits();
const unsigned TruncSize = TruncTy.getScalarSizeInBits();
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/AIE/aie2/bfloat16_to_float.ll
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ define dso_local noundef float @bfloat16_to_float_test(%class.bfloat16 %bf.coerc
; CHECK: .p2align 4
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: nopb ; nopa ; nops ; ret lr ; nopm ; nopv
; CHECK-NEXT: nop // Delay Slot 5
; CHECK-NEXT: nopx // Delay Slot 5
; CHECK-NEXT: nop // Delay Slot 4
; CHECK-NEXT: nop // Delay Slot 3
; CHECK-NEXT: mova r0, #16; extend.u16 r1, r1 // Delay Slot 2
; CHECK-NEXT: mova r0, #16 // Delay Slot 2
; CHECK-NEXT: lshl r0, r1, r0 // Delay Slot 1
entry:
%bf.coerce.fca.0.extract = extractvalue %class.bfloat16 %bf.coerce, 0
Expand Down

0 comments on commit 84fedf7

Please sign in to comment.