Skip to content

Commit

Permalink
CodeGen: Optimize arithmetics for basic identities (#1545)
Browse files Browse the repository at this point in the history
This change folds:

	a * 1 => a
	a / 1 => a
	a * -1 => -a
	a / -1 => -a
	a * 2 => a + a
	a / 2^k => a * 2^-k
	a - 0 => a
	a + (-0) => a

Note that the following folds are all invalid:

	a + 0 => a (breaks for negative zero)
	a - (-0) => a (breaks for negative zero)
	a - a => 0 (breaks for Inf/NaN)
	0 - a => -a (breaks for negative zero)

Various cases of UNM_NUM could be optimized (eg (-a) * (-b) = a * b),
but that doesn't happen in benchmarks.

While it would be possible to also fold inverse multiplications (k * v),
these do not happen in benchmarks and rarely happen in bytecode due
to type based optimizations. Maybe this can be improved with some sort
of
IR canonicalization in the future if necessary.

I've considered moving some of these, like division strength reduction,
to IR translation (as this is where POW is lowered presently) but it
didn't
seem better one way or the other.

This change improves performance on some benchmarks, e.g. trig and
voxelgen,
and should be a strict uplift as it never generates more instructions or
longer
latency chains. On Apple M2, without division->multiplication
optimization, both
benchmarks see 0.1-0.2% uplift. Division optimization makes trig 3%
faster; I expect
the gains on X64 will be more muted, but on Apple this seems to allow
loop iterations
to overlap better by removing the division bottleneck.
  • Loading branch information
zeux authored Nov 27, 2024
1 parent d19a5f0 commit b5801d3
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 6 deletions.
59 changes: 59 additions & 0 deletions CodeGen/src/OptimizeConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "lua.h"

#include <limits.h>
#include <math.h>

#include <array>
#include <utility>
Expand All @@ -19,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64)
LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64)
LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks)
LUAU_FASTFLAG(LuauVectorLibNativeDot);
LUAU_FASTFLAGVARIABLE(LuauCodeGenArithOpt);

namespace Luau
{
Expand Down Expand Up @@ -1192,10 +1194,67 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
break;
case IrCmd::ADD_INT:
case IrCmd::SUB_INT:
state.substituteOrRecord(inst, index);
break;
case IrCmd::ADD_NUM:
case IrCmd::SUB_NUM:
if (FFlag::LuauCodeGenArithOpt)
{
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
{
// a + 0.0 and a - (-0.0) can't be folded since the behavior is different for negative zero
// however, a - 0.0 and a + (-0.0) can be folded into a
if (*k == 0.0 && bool(signbit(*k)) == (inst.cmd == IrCmd::ADD_NUM))
substitute(function, inst, inst.a);
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
break;
case IrCmd::MUL_NUM:
if (FFlag::LuauCodeGenArithOpt)
{
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
{
if (*k == 1.0) // a * 1.0 = a
substitute(function, inst, inst.a);
else if (*k == 2.0) // a * 2.0 = a + a
replace(function, block, index, {IrCmd::ADD_NUM, inst.a, inst.a});
else if (*k == -1.0) // a * -1.0 = -a
replace(function, block, index, {IrCmd::UNM_NUM, inst.a});
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
break;
case IrCmd::DIV_NUM:
if (FFlag::LuauCodeGenArithOpt)
{
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
{
if (*k == 1.0) // a / 1.0 = a
substitute(function, inst, inst.a);
else if (*k == -1.0) // a / -1.0 = -a
replace(function, block, index, {IrCmd::UNM_NUM, inst.a});
else if (int exp = 0; frexp(*k, &exp) == 0.5 && exp >= -1000 && exp <= 1000) // a / 2^k = a * 2^-k
replace(function, block, index, {IrCmd::MUL_NUM, inst.a, build.constDouble(1.0 / *k)});
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
break;
case IrCmd::IDIV_NUM:
case IrCmd::MOD_NUM:
case IrCmd::MIN_NUM:
Expand Down
8 changes: 4 additions & 4 deletions tests/IrLowering.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ TEST_CASE("VectorCustomAccess")
CHECK_EQ(
"\n" + getCodegenAssembly(R"(
local function vec3magn(a: vector)
return a.Magnitude * 2
return a.Magnitude * 3
end
)"),
R"(
Expand All @@ -560,7 +560,7 @@ end
%12 = ADD_NUM %9, %10
%13 = ADD_NUM %12, %11
%14 = SQRT_NUM %13
%20 = MUL_NUM %14, 2
%20 = MUL_NUM %14, 3
STORE_DOUBLE R1, %20
STORE_TAG R1, tnumber
INTERRUPT 3u
Expand Down Expand Up @@ -1167,7 +1167,7 @@ local function inl(v: vector, s: number)
end
local function getsum(x)
return inl(x, 2) + inl(x, 5)
return inl(x, 3) + inl(x, 5)
end
)",
/* includeIrTypes */ true
Expand Down Expand Up @@ -1195,7 +1195,7 @@ end
bb_bytecode_0:
CHECK_TAG R0, tvector, exit(0)
%2 = LOAD_FLOAT R0, 4i
%8 = MUL_NUM %2, 2
%8 = MUL_NUM %2, 3
%13 = LOAD_FLOAT R0, 4i
%19 = MUL_NUM %13, 5
%28 = ADD_NUM %8, %19
Expand Down
14 changes: 12 additions & 2 deletions tests/conformance/basic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ assert((function() local a = 1 a = a - 2 return a end)() == -1)
assert((function() local a = 1 a = a * 2 return a end)() == 2)
assert((function() local a = 1 a = a / 2 return a end)() == 0.5)

-- binary ops with fp specials, neg zero, large constants
-- argument is passed into anonymous function to prevent constant folding
assert((function(a) return tostring(a + 0) end)(-0) == "0")
assert((function(a) return tostring(a - 0) end)(-0) == "-0")
assert((function(a) return tostring(0 - a) end)(0) == "0")
assert((function(a) return tostring(a - a) end)(1 / 0) == "nan")
assert((function(a) return tostring(a * 0) end)(0 / 0) == "nan")
assert((function(a) return tostring(a / (2^1000)) end)(2^1000) == "1")
assert((function(a) return tostring(a / (2^-1000)) end)(2^-1000) == "1")

-- floor division should always round towards -Infinity
assert((function() local a = 1 a = a // 2 return a end)() == 0)
assert((function() local a = 3 a = a // 2 return a end)() == 1)
Expand Down Expand Up @@ -290,7 +300,7 @@ assert((function() local t = {[1] = 1, [2] = 2} return t[1] + t[2] end)() == 3)
assert((function() return table.concat({}, ',') end)() == "")
assert((function() return table.concat({1}, ',') end)() == "1")
assert((function() return table.concat({1,2}, ',') end)() == "1,2")
assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() ==
assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() ==
"1,2,3,4,5,6,7,8,9,10,11,12,13,14,15")
assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16")
assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17")
Expand Down Expand Up @@ -770,7 +780,7 @@ assert(tostring(0) == "0")
assert(tostring(-0) == "-0")

-- test newline handling in long strings
assert((function()
assert((function()
local s1 = [[
]]
local s2 = [[
Expand Down

0 comments on commit b5801d3

Please sign in to comment.