Skip to content

Commit

Permalink
ARM64-SVE: Add ShiftRightArithmeticForDivide (#104279)
Browse files Browse the repository at this point in the history
  • Loading branch information
amanasifkhalid authored Jul 3, 2024
1 parent 4f96b8f commit 89d63f9
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 13 deletions.
40 changes: 28 additions & 12 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,8 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
assert(instrIsRMW);

insScalableOpts sopt;
insScalableOpts sopt = INS_SCALABLE_OPTS_NONE;
bool hasShift = false;

switch (intrinEmbMask.id)
{
Expand All @@ -601,17 +602,34 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
assert(emitter::optGetSveInsOpt(op2Size) == INS_OPTS_SCALABLE_D);
sopt = INS_SCALABLE_OPTS_WIDE;
break;
}

FALLTHROUGH;
break;
}

case NI_Sve_ShiftRightArithmeticForDivide:
hasShift = true;
break;

default:
sopt = INS_SCALABLE_OPTS_NONE;
break;
}

// Local helper that emits the embedded-mask instruction `insEmbMask` for the given
// registers. It captures the enclosing state by reference (`hasShift`, `intrinEmbMask`,
// `op2`, `insEmbMask`, `emitSize`, `opt`, `sopt`), so it is only meaningful while those
// locals are live.
//
// reg1, reg2 - passed as the first and second register operands in both forms.
// reg3       - passed as the third register operand of the register form only; it is
//              ignored when `hasShift` is set.
auto emitInsHelper = [&](regNumber reg1, regNumber reg2, regNumber reg3) {
if (hasShift)
{
// Immediate form: the last operand is an immediate supplied by the helper rather than
// a register. The emit call runs once per iteration of the helper's
// EmitBegin/Done/EmitCaseEnd loop.
// NOTE(review): presumably this is a single iteration when the immediate is a contained
// constant and one case per possible value otherwise - confirm against
// HWIntrinsicImmOpHelper.
HWIntrinsicImmOpHelper helper(this, intrinEmbMask.op2, op2->AsHWIntrinsic());
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
GetEmitter()->emitInsSve_R_R_I(insEmbMask, emitSize, reg1, reg2, helper.ImmValue(), opt,
sopt);
}
}
else
{
// Register form: all three registers are passed straight through to the emitter.
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, reg1, reg2, reg3, opt, sopt);
}
};

if (intrin.op3->IsVectorZero())
{
// If `falseReg` is zero, then move the first operand of `intrinEmbMask` in the
Expand All @@ -622,7 +640,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)

// Finally, perform the actual "predicated" operation so that `targetReg` is the first operand
// and `embMaskOp2Reg` is the second operand.
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg, opt, sopt);
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
}
else if (targetReg != falseReg)
{
Expand All @@ -636,8 +654,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
// If the embedded instruction supports optional mask operation, use the "unpredicated"
// version of the instruction, followed by "sel" to select the active lanes.
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, embMaskOp1Reg,
embMaskOp2Reg, opt, sopt);
emitInsHelper(targetReg, embMaskOp1Reg, embMaskOp2Reg);
}
else
{
Expand All @@ -651,8 +668,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)

GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, embMaskOp1Reg);

GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg,
opt, sopt);
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
}

GetEmitter()->emitIns_R_R_R_R(INS_sve_sel, emitSize, targetReg, maskReg, targetReg,
Expand All @@ -669,13 +685,13 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)

// Finally, perform the actual "predicated" operation so that `targetReg` is the first operand
// and `embMaskOp2Reg` is the second operand.
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg, opt, sopt);
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
}
else
{
// Just perform the actual "predicated" operation so that `targetReg` is the first operand
// and `embMaskOp2Reg` is the second operand.
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg, opt, sopt);
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
}

break;
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ HARDWARE_INTRINSIC(Sve, SaturatingIncrementByActiveElementCount,
HARDWARE_INTRINSIC(Sve, Scale, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fscale, INS_sve_fscale}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, ShiftLeftLogical, -1, -1, false, {INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, ShiftRightArithmetic, -1, -1, false, {INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, ShiftRightArithmeticForDivide, -1, -1, false, {INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_invalid, INS_invalid}, HW_Category_ShiftRightByImmediate, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_HasImmediateOperand)
HARDWARE_INTRINSIC(Sve, ShiftRightLogical, -1, -1, false, {INS_invalid, INS_sve_lsr, INS_invalid, INS_sve_lsr, INS_invalid, INS_sve_lsr, INS_invalid, INS_sve_lsr, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, SignExtend16, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sxth, INS_invalid, INS_sve_sxth, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, SignExtend32, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sxtw, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
Expand Down
14 changes: 13 additions & 1 deletion src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3371,6 +3371,8 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
// Handle op2
if (op2->OperIsHWIntrinsic())
{
const GenTreeHWIntrinsic* embOp = op2->AsHWIntrinsic();

if (IsInvariantInRange(op2, node) && op2->isEmbeddedMaskingCompatibleHWIntrinsic())
{
uint32_t maskSize = genTypeSize(node->GetSimdBaseType());
Expand All @@ -3386,7 +3388,6 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
{
// Else check if this operation has an auxiliary type that matches the
// mask size.
GenTreeHWIntrinsic* embOp = op2->AsHWIntrinsic();

// For now, make sure that we get here only for intrinsics for which we are
// sure we can rely on the auxiliary type's size.
Expand All @@ -3403,6 +3404,17 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
}
}
}

// Handle intrinsics with embedded masks and immediate operands
// (For now, just handle ShiftRightArithmeticForDivide specifically)
if (embOp->GetHWIntrinsicId() == NI_Sve_ShiftRightArithmeticForDivide)
{
assert(embOp->GetOperandCount() == 2);
if (embOp->Op(2)->IsCnsIntOrI())
{
MakeSrcContained(op2, embOp->Op(2));
}
}
}

// Handle op3
Expand Down
13 changes: 13 additions & 0 deletions src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,19 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
else
{
assert((numArgs == 1) || (numArgs == 2) || (numArgs == 3));

// Special handling for ShiftRightArithmeticForDivide:
// We might need an additional register to hold branch targets into the switch table
// that encodes the immediate
if (intrinEmb.id == NI_Sve_ShiftRightArithmeticForDivide)
{
assert(embOp2Node->GetOperandCount() == 2);
if (!embOp2Node->Op(2)->isContainedIntOrIImmed())
{
buildInternalIntRegisterDefForNode(embOp2Node);
}
}

tgtPrefUse = BuildUse(embOp2Node->Op(1));
srcCount += 1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6351,6 +6351,45 @@ internal Arm64() { }
public static unsafe Vector<sbyte> ShiftRightArithmetic(Vector<sbyte> left, Vector<ulong> right) { throw new PlatformNotSupportedException(); }


/// Arithmetic shift right for divide by immediate

/// <summary>
/// svint16_t svasrd[_n_s16]_m(svbool_t pg, svint16_t op1, uint64_t imm2)
/// ASRD Ztied1.H, Pg/M, Ztied1.H, #imm2
/// svint16_t svasrd[_n_s16]_x(svbool_t pg, svint16_t op1, uint64_t imm2)
/// ASRD Ztied1.H, Pg/M, Ztied1.H, #imm2
/// svint16_t svasrd[_n_s16]_z(svbool_t pg, svint16_t op1, uint64_t imm2)
/// </summary>
public static unsafe Vector<short> ShiftRightArithmeticForDivide(Vector<short> value, [ConstantExpected(Min = 1, Max = (byte)(16))] byte control) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint32_t svasrd[_n_s32]_m(svbool_t pg, svint32_t op1, uint64_t imm2)
/// ASRD Ztied1.S, Pg/M, Ztied1.S, #imm2
/// svint32_t svasrd[_n_s32]_x(svbool_t pg, svint32_t op1, uint64_t imm2)
/// ASRD Ztied1.S, Pg/M, Ztied1.S, #imm2
/// svint32_t svasrd[_n_s32]_z(svbool_t pg, svint32_t op1, uint64_t imm2)
/// </summary>
public static unsafe Vector<int> ShiftRightArithmeticForDivide(Vector<int> value, [ConstantExpected(Min = 1, Max = (byte)(32))] byte control) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svasrd[_n_s64]_m(svbool_t pg, svint64_t op1, uint64_t imm2)
/// ASRD Ztied1.D, Pg/M, Ztied1.D, #imm2
/// svint64_t svasrd[_n_s64]_x(svbool_t pg, svint64_t op1, uint64_t imm2)
/// ASRD Ztied1.D, Pg/M, Ztied1.D, #imm2
/// svint64_t svasrd[_n_s64]_z(svbool_t pg, svint64_t op1, uint64_t imm2)
/// </summary>
public static unsafe Vector<long> ShiftRightArithmeticForDivide(Vector<long> value, [ConstantExpected(Min = 1, Max = (byte)(64))] byte control) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint8_t svasrd[_n_s8]_m(svbool_t pg, svint8_t op1, uint64_t imm2)
/// ASRD Ztied1.B, Pg/M, Ztied1.B, #imm2
/// svint8_t svasrd[_n_s8]_x(svbool_t pg, svint8_t op1, uint64_t imm2)
/// ASRD Ztied1.B, Pg/M, Ztied1.B, #imm2
/// svint8_t svasrd[_n_s8]_z(svbool_t pg, svint8_t op1, uint64_t imm2)
/// </summary>
public static unsafe Vector<sbyte> ShiftRightArithmeticForDivide(Vector<sbyte> value, [ConstantExpected(Min = 1, Max = (byte)(8))] byte control) { throw new PlatformNotSupportedException(); }


/// Logical shift right

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6395,6 +6395,45 @@ internal Arm64() { }
public static unsafe Vector<sbyte> ShiftRightArithmetic(Vector<sbyte> left, Vector<ulong> right) => ShiftRightArithmetic(left, right);


/// Arithmetic shift right for divide by immediate

/// <summary>
/// svint16_t svasrd[_n_s16]_m(svbool_t pg, svint16_t op1, uint64_t imm2)
/// ASRD Ztied1.H, Pg/M, Ztied1.H, #imm2
/// svint16_t svasrd[_n_s16]_x(svbool_t pg, svint16_t op1, uint64_t imm2)
/// ASRD Ztied1.H, Pg/M, Ztied1.H, #imm2
/// svint16_t svasrd[_n_s16]_z(svbool_t pg, svint16_t op1, uint64_t imm2)
/// </summary>
public static unsafe Vector<short> ShiftRightArithmeticForDivide(Vector<short> value, [ConstantExpected(Min = 1, Max = (byte)(16))] byte control) => ShiftRightArithmeticForDivide(value, control);

/// <summary>
/// svint32_t svasrd[_n_s32]_m(svbool_t pg, svint32_t op1, uint64_t imm2)
/// ASRD Ztied1.S, Pg/M, Ztied1.S, #imm2
/// svint32_t svasrd[_n_s32]_x(svbool_t pg, svint32_t op1, uint64_t imm2)
/// ASRD Ztied1.S, Pg/M, Ztied1.S, #imm2
/// svint32_t svasrd[_n_s32]_z(svbool_t pg, svint32_t op1, uint64_t imm2)
/// </summary>
public static unsafe Vector<int> ShiftRightArithmeticForDivide(Vector<int> value, [ConstantExpected(Min = 1, Max = (byte)(32))] byte control) => ShiftRightArithmeticForDivide(value, control);

/// <summary>
/// svint64_t svasrd[_n_s64]_m(svbool_t pg, svint64_t op1, uint64_t imm2)
/// ASRD Ztied1.D, Pg/M, Ztied1.D, #imm2
/// svint64_t svasrd[_n_s64]_x(svbool_t pg, svint64_t op1, uint64_t imm2)
/// ASRD Ztied1.D, Pg/M, Ztied1.D, #imm2
/// svint64_t svasrd[_n_s64]_z(svbool_t pg, svint64_t op1, uint64_t imm2)
/// </summary>
public static unsafe Vector<long> ShiftRightArithmeticForDivide(Vector<long> value, [ConstantExpected(Min = 1, Max = (byte)(64))] byte control) => ShiftRightArithmeticForDivide(value, control);

/// <summary>
/// svint8_t svasrd[_n_s8]_m(svbool_t pg, svint8_t op1, uint64_t imm2)
/// ASRD Ztied1.B, Pg/M, Ztied1.B, #imm2
/// svint8_t svasrd[_n_s8]_x(svbool_t pg, svint8_t op1, uint64_t imm2)
/// ASRD Ztied1.B, Pg/M, Ztied1.B, #imm2
/// svint8_t svasrd[_n_s8]_z(svbool_t pg, svint8_t op1, uint64_t imm2)
/// </summary>
public static unsafe Vector<sbyte> ShiftRightArithmeticForDivide(Vector<sbyte> value, [ConstantExpected(Min = 1, Max = (byte)(8))] byte control) => ShiftRightArithmeticForDivide(value, control);


/// Logical shift right

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5118,6 +5118,10 @@ internal Arm64() { }
public static System.Numerics.Vector<long> ShiftRightArithmetic(System.Numerics.Vector<long> left, System.Numerics.Vector<ulong> right) { throw null; }
public static System.Numerics.Vector<sbyte> ShiftRightArithmetic(System.Numerics.Vector<sbyte> left, System.Numerics.Vector<byte> right) { throw null; }
public static System.Numerics.Vector<sbyte> ShiftRightArithmetic(System.Numerics.Vector<sbyte> left, System.Numerics.Vector<ulong> right) { throw null; }
public static System.Numerics.Vector<short> ShiftRightArithmeticForDivide(System.Numerics.Vector<short> value, [ConstantExpected(Min = 1, Max = (byte)(16))] byte control) { throw null; }
public static System.Numerics.Vector<int> ShiftRightArithmeticForDivide(System.Numerics.Vector<int> value, [ConstantExpected(Min = 1, Max = (byte)(32))] byte control) { throw null; }
public static System.Numerics.Vector<long> ShiftRightArithmeticForDivide(System.Numerics.Vector<long> value, [ConstantExpected(Min = 1, Max = (byte)(64))] byte control) { throw null; }
public static System.Numerics.Vector<sbyte> ShiftRightArithmeticForDivide(System.Numerics.Vector<sbyte> value, [ConstantExpected(Min = 1, Max = (byte)(8))] byte control) { throw null; }
public static System.Numerics.Vector<byte> ShiftRightLogical(System.Numerics.Vector<byte> left, System.Numerics.Vector<byte> right) { throw null; }
public static System.Numerics.Vector<byte> ShiftRightLogical(System.Numerics.Vector<byte> left, System.Numerics.Vector<ulong> right) { throw null; }
public static System.Numerics.Vector<ushort> ShiftRightLogical(System.Numerics.Vector<ushort> left, System.Numerics.Vector<ushort> right) { throw null; }
Expand Down
Loading

0 comments on commit 89d63f9

Please sign in to comment.