[inductor][cpp] improve vector contiguous checks for FloorDiv and ModularIndexing (pytorch#117221)

Fix pytorch#114488

This PR enables contiguous vector loads for cases where we can reduce `FloorDiv` and `ModularIndexing` in the index expression within the vectorized loop.

Take the index expression in the test case `test_vec_contiguous_ModularIndexing` as an example.
`14336*x0 + 256*x1 + 128*((x2//256)) + ModularIndexing(x2, 1, 128) + 7168*ModularIndexing(x2, 128, 2)` can be reduced to `14336*x0 + 256*x1 + x2 + 128*x2_div_c0 + 7168*x2_mod_c0 + x2_mod_c1`, where `x2` is the vectorized loop variable and the vector length is 16. The reduced form is linear in `x2` with stride 1, which means we can do a contiguous vectorized load for this index. Check the code comment for more details:
https://github.com/pytorch/pytorch/pull/117221/files#diff-5ab7b0235e2076a5fc6629ba0b109208940f5b94f5c13babc3e0f87cf4fcec82R317-R329
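
For illustration, the snippet below is a minimal standalone sketch (not the PR's implementation) that reproduces this reduction with plain sympy plus PyTorch's `FloorDiv`/`ModularIndexing` helpers; the free-variable names `x2_div_c0`, `x2_mod_c0`, and `x2_mod_c1` are taken from the example above.

```python
import sympy

from torch.utils._sympy.functions import FloorDiv, ModularIndexing

# The index expression from the example, with a vector length of 16.
x0, x1, x2 = sympy.symbols("x0 x1 x2", integer=True, nonnegative=True)
index = (
    14336 * x0
    + 256 * x1
    + 128 * FloorDiv(x2, 256)
    + ModularIndexing(x2, 1, 128)
    + 7168 * ModularIndexing(x2, 128, 2)
)

# Within one 16-lane vector, x2 = 16*a + b with 0 <= b < 16, so:
# - FloorDiv(x2, 256) and ModularIndexing(x2, 128, 2) are invariant across the
#   lanes (their divisors are multiples of 16) and act as free variables;
# - ModularIndexing(x2, 1, 128) == x2 - 128*FloorDiv(x2, 128), i.e. x2 plus a
#   lane-invariant offset, because the modulus 128 is a multiple of 16.
x2_div_c0, x2_mod_c0, x2_mod_c1 = sympy.symbols("x2_div_c0 x2_mod_c0 x2_mod_c1")
simplified = (
    index.subs(FloorDiv(x2, 256), x2_div_c0)
    .subs(ModularIndexing(x2, 128, 2), x2_mod_c0)
    .subs(ModularIndexing(x2, 1, 128), x2 + x2_mod_c1)
)
print(simplified)
# 14336*x0 + 256*x1 + x2 + 128*x2_div_c0 + 7168*x2_mod_c0 + x2_mod_c1

# The stride w.r.t. x2 (cf. stride_at in cpp.py) is now the constant 1,
# so a contiguous 16-lane load is legal for this index.
assert sympy.simplify(simplified.subs(x2, x2 + 1) - simplified) == 1
```

In the PR itself this rewriting is automated by the new `simplify_index_in_vec_range` helper below, which matches `FloorDiv`/`ModularIndexing` terms with `sympy.Wild` patterns instead of hard-coded substitutions.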

Pull Request resolved: pytorch#117221
Approved by: https://github.com/jansel
jgong5 authored and pytorchmergebot committed Jan 12, 2024
1 parent 6c624aa commit 172dd13
Showing 2 changed files with 119 additions and 24 deletions.
39 changes: 38 additions & 1 deletion test/inductor/test_cpu_repro.py
@@ -2724,10 +2724,47 @@ def fn(x):
             return y.softmax(dim=-1)
 
         x = torch.randn(128, 2048)
+        opt_fn = torch.compile(fn)
         metrics.reset()
-        self.common(fn, (x,))
+        _, code = run_and_get_cpp_code(opt_fn, x)
+        self.assertTrue(same(fn(x), opt_fn(x)))
         # 4 kernels for max, exp, sum and div
         assert metrics.generated_cpp_vec_kernel_count == 4
+        FileCheck().check_count(
+            "Vectorized<int>::loadu(tmpbuf.data())", 0, exactly=True
+        ).run(code)
+
+    def test_vec_contiguous_ModularIndexing(self):
+        # https://github.com/pytorch/pytorch/issues/114488
+        class M(torch.nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.norm = torch.nn.LayerNorm(dim * 4)
+
+            def forward(self, x):
+                # the pattern from swin_base_patch4_window7_224
+                B, H, W, C = x.shape
+                x = (
+                    x.reshape(B, H // 2, 2, W // 2, 2, C)
+                    .permute(0, 1, 3, 4, 2, 5)
+                    .flatten(3)
+                )
+                x = self.norm(x)
+                return x
+
+        x = torch.randn(1, 56, 56, 128)
+        m = M(128)
+        opt_m = torch.compile(m)
+        with torch.no_grad():
+            metrics.reset()
+            _, code = run_and_get_cpp_code(opt_m, x)
+            self.assertTrue(same(m(x), opt_m(x)))
+            # Two kernels: one for the reduction, one for pointwise ops
+            assert metrics.generated_cpp_vec_kernel_count == 2
+            # Only one kernel has a non-contiguous load
+            FileCheck().check_count(
+                "Vectorized<float>::loadu(tmpbuf.data())", 1, exactly=True
+            ).run(code)
 
 
 if __name__ == "__main__":
104 changes: 81 additions & 23 deletions torch/_inductor/codegen/cpp.py
@@ -16,7 +16,7 @@
 from torch._inductor import dependencies
 from torch._inductor.ir import StorageBox, TensorBox
 from torch._prims_common import is_float_dtype
-from torch.utils._sympy.functions import FloorDiv
+from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
 
 from .. import codecache, config, ir, metrics
@@ -311,6 +311,69 @@ def stride_at(var: sympy.Symbol, index: sympy.Expr):
     return sympy.simplify(new_index - index)
 
 
+@functools.lru_cache
+def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int):
+    """
+    Simplifies the index expression within the range of a vectorized loop.
+    Given a vectorized loop variable `var` in the range of a loop with `vec_length`,
+    this function transforms the `index` into an equivalent form. It handles
+    simplifications for cases where `var` can be expressed as `vec_length * a + b`,
+    where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences
+    of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations.
+
+    NOTE:
+    The simplified index expression is intended for analysis purposes only, not
+    for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables
+    which are not dependent on the loop variable `var` in the vectorized range. Check
+    https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details.
+
+    Examples:
+    1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then
+       `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable
+       when `div` is divisible by 16.
+    2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free
+       variable when `mod` is divisible by 16.
+    """
+
+    div_freevar_id = 0
+    mod_freevar_id = 0
+
+    def visit_indexing_div(divisor):
+        nonlocal div_freevar_id
+        result = FloorDiv(var, divisor)
+        if sympy.gcd(divisor, vec_length) == vec_length:
+            result = sympy.Symbol(f"{var}_div_c{div_freevar_id}")
+            div_freevar_id += 1
+        return result
+
+    def visit_modular_indexing(divisor, modulus):
+        nonlocal mod_freevar_id
+        result = ModularIndexing(var, divisor, modulus)
+        if sympy.gcd(divisor, vec_length) == vec_length:
+            result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
+            mod_freevar_id += 1
+        elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length:
+            result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
+            mod_freevar_id += 1
+        return result
+
+    original_index = index
+
+    div = sympy.Wild("divisor")
+    if index.has(FloorDiv):
+        index = index.replace(FloorDiv(var, div), visit_indexing_div)
+
+    mod = sympy.Wild("modulus")
+    if index.has(ModularIndexing):
+        index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing)
+
+    index = sympy.simplify(index)
+    if index != original_index:
+        return simplify_index_in_vec_range(index, var, vec_length)
+
+    return index
+
+
 class CppPrinter(ExprPrinter):
     def _print_Integer(self, expr):
         return f"{int(expr)}L"
@@ -1355,12 +1418,15 @@ def index_expr(expr, dtype):
         assert isinstance(V.kernel, CppVecKernel)
         index = V.kernel.rename_indexing(expr)
         tiling_var = V.kernel.itervars[V.kernel.tiling_idx]
-        if V.kernel.index_is_vector_invariant(index):
-            return CppOverrides.index_expr(expr, dtype)
-        if stride_at(
-            tiling_var, index
-        ).is_number and not V.kernel.index_indirect_depends_on(index, tiling_var):
-            stride = stride_at(tiling_var, index)
+        index_vec_simplified = simplify_index_in_vec_range(
+            index, tiling_var, V.kernel.tiling_factor
+        )
+        stride = stride_at(tiling_var, index_vec_simplified)
+        if stride.is_number and not V.kernel.index_indirect_depends_on(
+            index, tiling_var
+        ):
+            if stride == 0:
+                return CppOverrides.index_expr(expr, dtype)
             value = ops.to_dtype(cexpr(index), dtype)
             if isinstance(value, OpsValue):
                 value = value.value
@@ -1711,18 +1777,6 @@ def __init__(
         self.tiling_idx = tiling_idx
         metrics.generated_cpp_vec_kernel_count += 1
 
-    def index_is_vector_invariant(self, index: sympy.Expr):
-        """`index` is either independent from the tiling itervar or unchanged in the vector range"""
-        tiling_var = self.itervars[self.tiling_idx]
-        if not self.index_depends_on(index, tiling_var):
-            return True
-        if not self.index_indirect_depends_on(index, tiling_var):
-            vec_range = [
-                sympy_subs(index, {tiling_var: i}) for i in range(self.tiling_factor)
-            ]
-            return all(expr == vec_range[0] for expr in vec_range)
-        return False
-
     def _get_vec_load_line(
         self,
         var: str,
@@ -1883,12 +1937,16 @@ def load(self, name: str, index: sympy.Expr):
         index = self.rename_indexing(index)
         dtype = V.graph.get_dtype(name)
         tiling_var = self.itervars[self.tiling_idx]
-        if self.index_is_vector_invariant(index):
+        index_vec_simplified = simplify_index_in_vec_range(
+            index, tiling_var, self.tiling_factor
+        )
+        stride = stride_at(tiling_var, index_vec_simplified)
+        if stride == 0:
             # load scalar and lazily broadcast it on demand
             return super().load(name, index)
-        non_contiguous = stride_at(
-            tiling_var, index
-        ) != 1 or self.index_indirect_depends_on(index, tiling_var)
+        non_contiguous = stride != 1 or self.index_indirect_depends_on(
+            index, tiling_var
+        )
         if non_contiguous:
             csevar = self.load_non_contiguous(var, index, dtype)
         else:
