Skip to content

Commit

Permalink
CodeGen: Extract all vector tag patching into TAG_VECTOR (#1171)
Browse files Browse the repository at this point in the history
Instead of patching the tag component with TVECTOR in every instruction
that produces a vector value, we now use a separate IR instruction to do
this. This reduces implementation redundancy, but more importantly
allows for a class of optimizations:

- NUM_TO_VECTOR previously patched the component unconditionally but the
result was used only in MUL/DIV_VEC instructions that ignore it anyway;
we can now remove this.

- ADD_VEC et al can now forward the source of the TAG_VECTOR instruction of
either input; this shortens the latency chain and in the future could
allow us to generate optimal vector instruction sequence once the
temporary stores are marked as dead.

- In the future on X64, ADD_VEC et al will be able to analyze the input
instruction and remove tag masking conditionally. This is not part of
this PR as it requires a decision around expected FP environment and/or
the necessity of the existing masking to begin with.

I've also renamed NUM_TO_VECTOR to NUM_TO_VEC so that "VEC" always
refers to "3 float values" and for consistency with ADD/etc.

Note: ADD_VEC input forwarding is currently performed unconditionally;
it may or may not increase the spills that can't be reloaded from the
stack.

On A64 this makes the Taylor series computation a tiny bit faster
(11.3 ns => 11.0 ns) as it removes the redundant `ins` instructions along
the NUM_TO_VEC path. Curiously, the optimization of forwarding
TAG_VECTOR input to arithmetic instructions actually has a small penalty
as without it this PR runs at 10.9 ns. I don't know if this is a
property of the benchmark though, as I just noticed that in this
benchmark type inference actually fails to infer parts of the
computation as a vector op. If desired I will happily omit this part of
the change and we can explore that separately.
  • Loading branch information
zeux authored Feb 21, 2024
1 parent c5f4d97 commit 80928ac
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 74 deletions.
6 changes: 5 additions & 1 deletion CodeGen/include/Luau/IrData.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,11 @@ enum class IrCmd : uint8_t

// Converts a double number to a vector with the value in X/Y/Z
// A: double
NUM_TO_VECTOR,
NUM_TO_VEC,

// Adds VECTOR type tag to a vector, preserving X/Y/Z components
// A: TValue
TAG_VECTOR,

// Adjust stack top (L->top) to point at 'B' TValues *after* the specified register
// This is used to return multiple values
Expand Down
3 changes: 2 additions & 1 deletion CodeGen/include/Luau/IrUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::UINT_TO_NUM:
case IrCmd::NUM_TO_INT:
case IrCmd::NUM_TO_UINT:
case IrCmd::NUM_TO_VECTOR:
case IrCmd::NUM_TO_VEC:
case IrCmd::TAG_VECTOR:
case IrCmd::SUBSTITUTE:
case IrCmd::INVOKE_FASTCALL:
case IrCmd::BITAND_UINT:
Expand Down
6 changes: 4 additions & 2 deletions CodeGen/src/IrDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,10 @@ const char* getCmdName(IrCmd cmd)
return "NUM_TO_INT";
case IrCmd::NUM_TO_UINT:
return "NUM_TO_UINT";
case IrCmd::NUM_TO_VECTOR:
return "NUM_TO_VECTOR";
case IrCmd::NUM_TO_VEC:
return "NUM_TO_VEC";
case IrCmd::TAG_VECTOR:
return "TAG_VECTOR";
case IrCmd::ADJUST_STACK_TO_REG:
return "ADJUST_STACK_TO_REG";
case IrCmd::ADJUST_STACK_TO_TOP:
Expand Down
66 changes: 50 additions & 16 deletions CodeGen/src/IrLoweringA64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCodeGenFixBufferLenCheckA64, false)
LUAU_FASTFLAGVARIABLE(LuauCodeGenVectorA64, false)

LUAU_FASTFLAG(LuauCodegenVectorTag)

namespace Luau
{
namespace CodeGen
Expand Down Expand Up @@ -678,9 +680,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
build.fadd(inst.regA64, regOp(inst.a), regOp(inst.b));

RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
if (!FFlag::LuauCodegenVectorTag)
{
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
}
else
{
Expand All @@ -705,9 +710,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
build.fsub(inst.regA64, regOp(inst.a), regOp(inst.b));

RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
if (!FFlag::LuauCodegenVectorTag)
{
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
}
else
{
Expand All @@ -732,9 +740,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
build.fmul(inst.regA64, regOp(inst.a), regOp(inst.b));

RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
if (!FFlag::LuauCodegenVectorTag)
{
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
}
else
{
Expand All @@ -759,9 +770,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
build.fdiv(inst.regA64, regOp(inst.a), regOp(inst.b));

RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
if (!FFlag::LuauCodegenVectorTag)
{
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
}
else
{
Expand All @@ -786,9 +800,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
build.fneg(inst.regA64, regOp(inst.a));

RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
if (!FFlag::LuauCodegenVectorTag)
{
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
}
else
{
Expand Down Expand Up @@ -1156,7 +1173,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.fcvtzs(castReg(KindA64::x, inst.regA64), temp);
break;
}
case IrCmd::NUM_TO_VECTOR:
case IrCmd::NUM_TO_VEC:
{
inst.regA64 = regs.allocReg(KindA64::q, index);

Expand All @@ -1167,6 +1184,23 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.fcvt(temps, tempd);
build.dup_4s(inst.regA64, castReg(KindA64::q, temps), 0);

if (!FFlag::LuauCodegenVectorTag)
{
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
break;
}
case IrCmd::TAG_VECTOR:
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a});

RegisterA64 reg = regOp(inst.a);
RegisterA64 tempw = regs.allocTemp(KindA64::w);

if (inst.regA64 != reg)
build.mov(inst.regA64, reg);

build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
break;
Expand Down
34 changes: 26 additions & 8 deletions CodeGen/src/IrLoweringX64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "lstate.h"
#include "lgc.h"

LUAU_FASTFLAG(LuauCodegenVectorTag)

namespace Luau
{
namespace CodeGen
Expand Down Expand Up @@ -608,7 +610,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp());
build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp());
build.vaddps(inst.regX64, tmp1.reg, tmp2.reg);
build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp());

if (!FFlag::LuauCodegenVectorTag)
build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp());
break;
}
case IrCmd::SUB_VEC:
Expand All @@ -622,7 +626,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp());
build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp());
build.vsubps(inst.regX64, tmp1.reg, tmp2.reg);
build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp());
if (!FFlag::LuauCodegenVectorTag)
build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp());
break;
}
case IrCmd::MUL_VEC:
Expand All @@ -636,7 +641,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp());
build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp());
build.vmulps(inst.regX64, tmp1.reg, tmp2.reg);
build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp());
if (!FFlag::LuauCodegenVectorTag)
build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp());
break;
}
case IrCmd::DIV_VEC:
Expand All @@ -650,7 +656,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp());
build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp());
build.vdivps(inst.regX64, tmp1.reg, tmp2.reg);
build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3);
if (!FFlag::LuauCodegenVectorTag)
build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3);
break;
}
case IrCmd::UNM_VEC:
Expand All @@ -669,7 +676,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vxorpd(inst.regX64, inst.regX64, build.f32x4(-0.0, -0.0, -0.0, -0.0));
}

build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3);
if (!FFlag::LuauCodegenVectorTag)
build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3);
break;
}
case IrCmd::NOT_ANY:
Expand Down Expand Up @@ -964,7 +972,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)

build.vcvttsd2si(qwordReg(inst.regX64), memRegDoubleOp(inst.a));
break;
case IrCmd::NUM_TO_VECTOR:
case IrCmd::NUM_TO_VEC:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);

if (inst.a.kind == IrOpKind::Constant)
Expand All @@ -974,15 +982,25 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
static_assert(sizeof(asU32) == sizeof(value), "Expecting float to be 32-bit");
memcpy(&asU32, &value, sizeof(value));

build.vmovaps(inst.regX64, build.u32x4(asU32, asU32, asU32, LUA_TVECTOR));
if (FFlag::LuauCodegenVectorTag)
build.vmovaps(inst.regX64, build.u32x4(asU32, asU32, asU32, 0));
else
build.vmovaps(inst.regX64, build.u32x4(asU32, asU32, asU32, LUA_TVECTOR));
}
else
{
build.vcvtsd2ss(inst.regX64, inst.regX64, memRegDoubleOp(inst.a));
build.vpshufps(inst.regX64, inst.regX64, inst.regX64, 0b00'00'00'00);
build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3);

if (!FFlag::LuauCodegenVectorTag)
build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3);
}
break;
case IrCmd::TAG_VECTOR:
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a});

build.vpinsrd(inst.regX64, regOp(inst.a), build.i32(LUA_TVECTOR), 3);
break;
case IrCmd::ADJUST_STACK_TO_REG:
{
ScopedRegX64 tmp{regs, SizeX64::qword};
Expand Down
22 changes: 17 additions & 5 deletions CodeGen/src/IrTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

LUAU_FASTFLAGVARIABLE(LuauCodegenLuData, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenVector, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenVectorTag, false)

namespace Luau
{
Expand Down Expand Up @@ -380,9 +381,12 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc,
result = build.inst(IrCmd::DIV_VEC, vb, vc);
break;
default:
break;
CODEGEN_ASSERT(!"Unknown TM op");
}

if (FFlag::LuauCodegenVectorTag)
result = build.inst(IrCmd::TAG_VECTOR, result);

build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result);
return;
}
Expand All @@ -393,7 +397,7 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc,

build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)), build.constTag(LUA_TVECTOR), build.vmExit(pcpos));

IrOp vb = build.inst(IrCmd::NUM_TO_VECTOR, loadDoubleOrConstant(build, opb));
IrOp vb = build.inst(IrCmd::NUM_TO_VEC, loadDoubleOrConstant(build, opb));
IrOp vc = build.inst(IrCmd::LOAD_TVALUE, opc);
IrOp result;

Expand All @@ -406,9 +410,12 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc,
result = build.inst(IrCmd::DIV_VEC, vb, vc);
break;
default:
break;
CODEGEN_ASSERT(!"Unknown TM op");
}

if (FFlag::LuauCodegenVectorTag)
result = build.inst(IrCmd::TAG_VECTOR, result);

build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result);
return;
}
Expand All @@ -420,7 +427,7 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc,
build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)), build.constTag(LUA_TNUMBER), build.vmExit(pcpos));

IrOp vb = build.inst(IrCmd::LOAD_TVALUE, opb);
IrOp vc = build.inst(IrCmd::NUM_TO_VECTOR, loadDoubleOrConstant(build, opc));
IrOp vc = build.inst(IrCmd::NUM_TO_VEC, loadDoubleOrConstant(build, opc));
IrOp result;

switch (tm)
Expand All @@ -432,9 +439,12 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc,
result = build.inst(IrCmd::DIV_VEC, vb, vc);
break;
default:
break;
CODEGEN_ASSERT(!"Unknown TM op");
}

if (FFlag::LuauCodegenVectorTag)
result = build.inst(IrCmd::TAG_VECTOR, result);

build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result);
return;
}
Expand Down Expand Up @@ -596,6 +606,8 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos)

IrOp vb = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb));
IrOp va = build.inst(IrCmd::UNM_VEC, vb);
if (FFlag::LuauCodegenVectorTag)
va = build.inst(IrCmd::TAG_VECTOR, va);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), va);
return;
}
Expand Down
3 changes: 2 additions & 1 deletion CodeGen/src/IrUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::NUM_TO_INT:
case IrCmd::NUM_TO_UINT:
return IrValueKind::Int;
case IrCmd::NUM_TO_VECTOR:
case IrCmd::NUM_TO_VEC:
case IrCmd::TAG_VECTOR:
return IrValueKind::Tvalue;
case IrCmd::ADJUST_STACK_TO_REG:
case IrCmd::ADJUST_STACK_TO_TOP:
Expand Down
Loading

0 comments on commit 80928ac

Please sign in to comment.