diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 30790ee5c..ca5fa7a9f 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -160,6 +160,7 @@ class AssemblyBuilderX64 void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 779fe0120..44f2495b6 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -185,6 +185,11 @@ enum class IrCmd : uint8_t // A: double SIGN_NUM, + // Select B if C == D, otherwise select A + // A, B: double (endpoints) + // C, D: double (condition arguments) + SELECT_NUM, + // Add/Sub/Mul/Div/Idiv two vectors // A, B: TValue ADD_VEC, diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 773b23a66..1afa1a34c 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -174,6 +174,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::SIGN_NUM: + case IrCmd::SELECT_NUM: case IrCmd::ADD_VEC: case IrCmd::SUB_VEC: case IrCmd::MUL_VEC: diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 803732e2c..1fb1b671f 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -927,6 +927,11 @@ void AssemblyBuilderX64::vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2 placeAvx("vminsd", dst, src1, src2, 0x5d, false, AVX_0F, AVX_F2); } +void AssemblyBuilderX64::vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + placeAvx("vcmpeqsd", dst, src1, src2, 0x00, 0xc2, false, AVX_0F, AVX_F2); +} + void AssemblyBuilderX64::vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) { placeAvx("vcmpltsd", dst, src1, src2, 0x01, 0xc2, false, AVX_0F, AVX_F2); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index fe6a2397c..dcc9d879e 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -169,6 +169,8 @@ const char* getCmdName(IrCmd cmd) return "ABS_NUM"; case IrCmd::SIGN_NUM: return "SIGN_NUM"; + case IrCmd::SELECT_NUM: + return "SELECT_NUM"; case IrCmd::ADD_VEC: return "ADD_VEC"; case IrCmd::SUB_VEC: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index c7fcac271..972519482 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -13,6 +13,7 @@ LUAU_FASTFLAG(LuauVectorLibNativeDot) LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim) +LUAU_FASTFLAG(LuauCodeGenLerp) namespace Luau { @@ -703,6 +704,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less)); break; } + case IrCmd::SELECT_NUM: + { + LUAU_ASSERT(FFlag::LuauCodeGenLerp); + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + RegisterA64 temp3 = tempDouble(inst.c); + RegisterA64 temp4 = tempDouble(inst.d); + + build.fcmp(temp3, temp4); + build.fcsel(inst.regA64, temp2, temp1, getConditionFP(IrCondition::Equal)); + break; + } case IrCmd::ADD_VEC: { inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b}); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 814c6d8cf..2017ce573 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -17,6 +17,7 @@ LUAU_FASTFLAG(LuauVectorLibNativeDot) LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim) +LUAU_FASTFLAG(LuauCodeGenLerp) namespace Luau { @@ -622,6 +623,30 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64); break; } + case IrCmd::SELECT_NUM: + { + LUAU_ASSERT(FFlag::LuauCodeGenLerp); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); // can't reuse b if a is a memory operand + + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + if (inst.c.kind == IrOpKind::Inst) + build.vcmpeqsd(tmp.reg, regOp(inst.c), memRegDoubleOp(inst.d)); + else + { + build.vmovsd(tmp.reg, memRegDoubleOp(inst.c)); + build.vcmpeqsd(tmp.reg, tmp.reg, memRegDoubleOp(inst.d)); + } + + if (inst.a.kind == IrOpKind::Inst) + build.vblendvpd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg); + else + { + build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + build.vblendvpd(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg); + } + break; + } case IrCmd::ADD_VEC: { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index ebded5227..fd247b5bb 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -15,6 +15,7 @@ static const int kBit32BinaryOpUnrolledParams = 5; LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen); LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot); +LUAU_FASTFLAGVARIABLE(LuauCodeGenLerp); namespace Luau { @@ -284,6 +285,42 @@ static BuiltinImplResult translateBuiltinMathClamp( return {BuiltinImplType::UsesFallback, 1}; } +static BuiltinImplResult translateBuiltinMathLerp( + IrBuilder& build, + int nparams, + int ra, + int arg, + IrOp args, + IrOp arg3, + int nresults, + IrOp fallback, + int pcpos +) +{ + LUAU_ASSERT(FFlag::LuauCodeGenLerp); + + if (nparams < 3 || nresults > 1) + return {BuiltinImplType::None, -1}; + + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); + builtinCheckDouble(build, arg3, pcpos); + + IrOp a = builtinLoadDouble(build, build.vmReg(arg)); + IrOp b = builtinLoadDouble(build, args); + IrOp t = builtinLoadDouble(build, arg3); + + IrOp l = build.inst(IrCmd::ADD_NUM, a, build.inst(IrCmd::MUL_NUM, build.inst(IrCmd::SUB_NUM, b, a), t)); + IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0), build.cond(IrCondition::Equal)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::Full, 1}; +} + static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) @@ -1387,6 +1424,8 @@ BuiltinImplResult translateBuiltin( case LBF_VECTOR_MAX: return FFlag::LuauVectorLibNativeCodegen ? translateBuiltinVectorMap2(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos) : noneResult; + case LBF_MATH_LERP: + return FFlag::LuauCodeGenLerp ? translateBuiltinMathLerp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos) : noneResult; default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 5f3848071..74bbc6d78 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -13,6 +13,7 @@ #include LUAU_FASTFLAG(LuauVectorLibNativeDot); +LUAU_FASTFLAG(LuauCodeGenLerp); namespace Luau { @@ -70,6 +71,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::SIGN_NUM: + case IrCmd::SELECT_NUM: return IrValueKind::Double; case IrCmd::ADD_VEC: case IrCmd::SUB_VEC: @@ -656,6 +658,16 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0)); } break; + case IrCmd::SELECT_NUM: + LUAU_ASSERT(FFlag::LuauCodeGenLerp); + if (inst.c.kind == IrOpKind::Constant && inst.d.kind == IrOpKind::Constant) + { + double c = function.doubleOp(inst.c); + double d = function.doubleOp(inst.d); + + substitute(function, inst, c == d ? inst.b : inst.a); + } + break; case IrCmd::NOT_ANY: if (inst.a.kind == IrOpKind::Constant) { diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 5920f7cc5..ce44f5d10 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -1382,6 +1382,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::SIGN_NUM: + case IrCmd::SELECT_NUM: case IrCmd::NOT_ANY: state.substituteOrRecord(inst, index); break; diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 504e40e4c..fd1deccf4 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -506,6 +506,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") SINGLE_COMPARE(vmaxsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5f, 0xc6); SINGLE_COMPARE(vminsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5d, 0xc6); + SINGLE_COMPARE(vcmpeqsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x00); SINGLE_COMPARE(vcmpltsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x01); } diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index fbd8f9ddb..586023edb 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -408,6 +408,7 @@ assert(math.lerp(1, 5, 1) == 5) assert(math.lerp(1, 5, 0.5) == 3) assert(math.lerp(1, 5, 1.5) == 7) assert(math.lerp(1, 5, -0.5) == -1) +assert(math.lerp(1, 5, noinline(0.5)) == 3) -- lerp properties local sq2, sq3 = math.sqrt(2), math.sqrt(3)