[TKW] Detect contiguous when lowering mapping #232

Draft: wants to merge 7 commits into base: main
192 changes: 169 additions & 23 deletions iree/turbine/kernel/wave/codegen.py
@@ -174,6 +174,20 @@ def bind_node_proxies(self, node: fx.Node, proxies: List[IRProxyValue]):
)
self._node_values[node] = proxies

def get_induction_vars_and_syms(self) -> tuple[list[OpResult], list[IndexExpr]]:
induction_var_syms = []
induction_vars = []
if self.induction_vars:
for constraint in self.constraints:
if isinstance(constraint, TilingConstraint):
assert (
constraint.dim in self.induction_vars
), f"Could not find induction var for {constraint.dim} dimension"
induction_var_syms.append(constraint.induction_var)
induction_vars.append(self.induction_vars[constraint.dim])

return induction_vars, induction_var_syms


def get_type_or_element_type(operand_type: IrType):
assert isinstance(operand_type, IrType)
@@ -184,16 +198,7 @@ def get_type_or_element_type(operand_type: IrType):


def add_emitter_subs(emitter: WaveEmitter) -> dict[IndexSymbol, Any]:
induction_var_syms = []
induction_vars = []
if emitter.induction_vars:
for constraint in emitter.constraints:
if isinstance(constraint, TilingConstraint):
assert (
constraint.dim in emitter.induction_vars
), f"Could not find induction var for {constraint.dim} dimension"
induction_var_syms.append(constraint.induction_var)
induction_vars.append(emitter.induction_vars[constraint.dim])
induction_vars, induction_var_syms = emitter.get_induction_vars_and_syms()

# TODO: factor this out
all_symbols = emitter.thread_ids + emitter.workgroup_ids + induction_vars
@@ -565,6 +570,117 @@ def _build_mask(
return mask


def _simplify_sympy_expr(expr: IndexExpr) -> IndexExpr:
def check_mul(mul):
ret = None
for arg in mul.args:
if arg.is_number:
if ret is not None:
return None

ret = arg
continue

if not isinstance(arg, (sympy.floor, sympy.Mod)):
return None

return ret

def transform_mod(expr):
if not isinstance(expr, sympy.Mod):
return None

p, q = expr.args
if not q.is_number:
return None

if not isinstance(p, sympy.Add):
return None

c = None
terms = []
mult = None
for arg in p.args:
if arg.is_number:
if c is not None:
return None

c = arg
continue

if not isinstance(arg, sympy.Mul):
return None

m = check_mul(arg)
if (m is None) or (q % m != 0):
return None

mult = m if (mult is None) or (m < mult) else mult
terms.append(arg)

if c >= mult:
return None

return (sum(terms) % q) + c

def check_mul_rational(mul):
ret = None
for arg in mul.args:
if isinstance(arg, sympy.Rational):
if ret is not None:
return None

ret = arg
continue

if not isinstance(arg, (sympy.floor, sympy.Mod)):
return None

return ret

def transform_floor(expr):
if not isinstance(expr, sympy.floor):
return None

expr = expr.args[0]
if not isinstance(expr, sympy.Add):
return None

c = None
for arg in expr.args:
if isinstance(arg, sympy.Rational):
if c is not None:
return None

c = arg

if c is None:
return None

terms = []
for arg in expr.args:
if isinstance(arg, sympy.Rational):
continue

if not isinstance(arg, sympy.Mul):
return None

r = check_mul_rational(arg)
if r is None:
return None

if r < c:
return None

terms.append(arg)

return sympy.floor(sum(terms))

expr = expr.replace(lambda e: transform_mod(e) is not None, transform_mod)
expr = expr.replace(lambda e: transform_floor(e) is not None, transform_floor)
return sympy.simplify(expr)
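
A minimal sketch of the two rewrites this helper targets, on illustrative expressions rather than ones taken from a real kernel index: hoisting a small constant out of a modulo whose remaining terms are multiples of a divisor of the modulus, and dropping a small rational offset under a floor. Plain `sympy.simplify` typically leaves both forms untouched, which is presumably why the dedicated helper is used for the contiguity check below.

```python
import sympy

# Assumed import path; the helper is private to codegen.py.
from iree.turbine.kernel.wave.codegen import _simplify_sympy_expr

x = sympy.Symbol("x")

# Mod rewrite: the constant 3 is smaller than the numeric factor 16, and 16
# divides the modulus 64, so the constant can be hoisted out of the modulo:
#   Mod(16*floor(x/4) + 3, 64)  ->  Mod(16*floor(x/4), 64) + 3
mod_expr = sympy.Mod(16 * sympy.floor(x / 4) + 3, 64)

# Floor rewrite: the rational offset 1/4 does not exceed the smallest rational
# coefficient 1/2, so it can be dropped under the floor:
#   floor(floor(x)/2 + 1/4)  ->  floor(floor(x)/2)
floor_expr = sympy.floor(sympy.floor(x) / 2 + sympy.Rational(1, 4))

print(_simplify_sympy_expr(mod_expr))
print(_simplify_sympy_expr(floor_expr))
```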


def _construct_gather_scatter_indices(
emitter: WaveEmitter,
symbolc_shape: tuple[IndexExpr],
@@ -594,9 +710,9 @@ def _construct_gather_scatter_indices(

# As we only support identity input/output mapping for now, we can directly
# substitute iterators with corresponding expanded index.
subs = [
subs = list(idxc.subs.items()) + [
(sym, expr.start) for sym, expr in zip(iters.keys(), index.values())
] + list(idxc.subs.items())
]

# Construct input/output index, substituting iterators in input mapping with
# expanded index.
@@ -606,6 +722,38 @@
offsets = []

start_indices = _get_start_indices(result_index)

expected_diff = [0] * len(start_indices)
expected_diff[-1] = 1
is_contiguous = True
subs[-1] = (subs[-1][0], (subs[-1][1] // elements_per_thread) * elements_per_thread)
prev_indices = _get_start_indices(
{key: m.subs(subs) for key, m in zip(symbolc_shape, index_mapping)}
)
for i in range(1, elements_per_thread, 1):
subs[-1] = (subs[-1][0], subs[-1][1] + 1)
next_result_index = {
key: m.subs(subs) for key, m in zip(symbolc_shape, index_mapping)
}
next_indices = _get_start_indices(next_result_index)
diff = [_simplify_sympy_expr(a - b) for a, b in zip(next_indices, prev_indices)]
if diff != expected_diff:
is_contiguous = False
break

prev_indices = next_indices

mask = _build_mask(emitter, index, elements_per_thread)
if mask is None:
mask_vec_type = VectorType.get(
[elements_per_thread], IntegerType.get_signless(1)
)
mask = vector_d.constant_mask(mask_vec_type, [elements_per_thread])

if is_contiguous:
start_indices = _build_start_indices(emitter, result_index)
return start_indices, None, mask

start_indices_orig = _get_start_indices(index)

need_dynamic_offsets = False
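
The block above decides, symbolically, whether the per-lane indices of a read or write are unit-stride in the fastest-varying dimension. A standalone spot-check of the same criterion, using plain integers and hypothetical index functions instead of the symbolic machinery (illustrative only):

```python
def is_contiguous(index_fn, elements_per_thread, base_lane=0):
    """index_fn maps a flat lane id to an n-D index tuple (hypothetical helper)."""
    # Align the first lane to a multiple of elements_per_thread, mirroring how
    # the symbolic check rounds the last substitution down before stepping.
    base = (base_lane // elements_per_thread) * elements_per_thread
    prev = index_fn(base)
    expected = (0,) * (len(prev) - 1) + (1,)
    for step in range(1, elements_per_thread):
        cur = index_fn(base + step)
        if tuple(a - b for a, b in zip(cur, prev)) != expected:
            return False
        prev = cur
    return True

row_major = lambda lane: (lane // 8, lane % 8)  # consecutive lanes stay in one row
strided = lambda lane: (lane % 8, lane // 8)    # consecutive lanes jump rows

print(is_contiguous(row_major, 4))  # True  -> maskedload/maskedstore
print(is_contiguous(strided, 4))    # False -> gather/scatter
```

When the criterion holds, the per-lane offsets would have been the constant vector [0, 1, ..., elements_per_thread - 1], so the gather/scatter can be replaced by a maskedload/maskedstore on the start indices alone, which is what `handle_read` and `handle_write` now do when `offsets_vec` is `None`.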
@@ -665,13 +813,6 @@ def _construct_gather_scatter_indices(
offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
)

mask = _build_mask(emitter, index, elements_per_thread)
if mask is None:
mask_vec_type = VectorType.get(
[elements_per_thread], IntegerType.get_signless(1)
)
mask = vector_d.constant_mask(mask_vec_type, [elements_per_thread])

return start_indices, offsets_vec, mask


@@ -724,9 +865,14 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
zero = arith_d.ConstantOp(vector_type.element_type, zero)
passthru = vector_d.splat(vector_type, zero)

result = vector_d.gather(
vector_type, kb_src, start_indices, offsets_vec, mask, passthru
)
if offsets_vec is None:
result = vector_d.maskedload(
vector_type, kb_src, start_indices, mask, passthru
)
else:
result = vector_d.gather(
vector_type, kb_src, start_indices, offsets_vec, mask, passthru
)

emitter.bind_node_proxy(node, IRProxyValue(result))

@@ -781,7 +927,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
is_read=False,
)

if elements_per_thread == 1:
if offsets_vec is None:
vector_d.maskedstore(kb_dest, start_indices, mask, insert_vector)
else:
vector_d.scatter(kb_dest, start_indices, offsets_vec, mask, insert_vector)
46 changes: 27 additions & 19 deletions lit_tests/kernel/wave/codegen.py
@@ -128,6 +128,7 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
# CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
# CHECK: %[[ARR:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<16x16xf16,
# CHECK-SAME: strided<[16, 1], offset: ?>>
# CHECK-DAG: %[[MASK:.+]] = vector.constant_mask [16] : vector<16xi1>
# CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
# CHECK: %[[D0:.+]] = arith.muli %[[THREAD_ID_X]], %[[C16]] overflow<nsw, nuw> : index
# CHECK-DAG: %[[C16_0:.+]] = arith.constant 16 : index
@@ -143,8 +144,7 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
# CHECK-DAG: %[[C17:.+]] = arith.constant 17 : index
# CHECK: %[[D7:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C17]] overflow<nsw, nuw> : index
# CHECK: %[[D8:.+]] = arith.addi %[[D7]], %[[D6]] overflow<nsw, nuw> : index
# CHECK: %[[CST:.+]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
# CHECK: %[[MASK:.+]] = vector.constant_mask [16] : vector<16xi1>
# CHECK-DAG: %[[CST:.+]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
# CHECK-DAG: %[[CST_2:.+]] = arith.constant 0.000000e+00 : f16
# CHECK: %[[D9:.+]] = vector.splat %[[CST_2]] : vector<16xf16>
# CHECK: %[[D10:.+]] = vector.gather %[[ARR]][%[[D5]], %[[D8]]] [%[[CST]]], %[[MASK]], %[[D9]] :
@@ -1026,22 +1026,27 @@ def test_igemm():
K = HF * WF * C
M = SZ_OUT * N

# Workgroup tile sizes
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K

i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)

x_mapping = tkw.IndexMapping(
num_iterators=2,
inputs={
N: i // SZ_OUT,
C: j // (HF * WF),
H: (i % SZ_OUT) % W_OUT * stride + (j % (HF * WF)) % WF,
W: (i % SZ_OUT) // W_OUT * stride + (j % (HF * WF)) // WF,
C: j % C,
H: (i % SZ_OUT) % W_OUT * stride + (j // C) % WF,
W: (i % SZ_OUT) // W_OUT * stride + (j // C) // WF,
},
outputs={M: i, K: j},
)
w_mapping = tkw.IndexMapping(
num_iterators=2,
inputs={NF: i % NF, C: j // (HF * WF), HF: j % WF, WF: (j % (HF * WF)) // WF},
inputs={NF: i % NF, C: j % C, HF: (j // C) % WF, WF: (j // C) // WF},
outputs={NF: i, K: j},
)
out_mapping = tkw.IndexMapping(
@@ -1055,10 +1060,6 @@
},
)

# Workgroup tile sizes
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = 16
# Address space (for GPU, shared(1) or global(0))
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
# Other hyperparameters
@@ -1072,18 +1073,21 @@
we = torch.permute(we, (2, 3, 1, 0)).contiguous()
out = torch.permute(out, (0, 2, 3, 1)).contiguous()

ratio_m = 2
ratio_n = 2

# Expose user-constraints
constraints: list[tkw.Constraint] = []
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(NF, BLOCK_N)]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 0)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)]
constraints += [tkw.WaveConstraint(NF, BLOCK_N / ratio_n)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
waves_per_block=(ratio_n, ratio_m, 1),
)
]

@@ -1123,8 +1127,9 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
NF: nf,
WF: wf,
HF: hf,
BLOCK_M: 16,
BLOCK_N: 16,
BLOCK_M: 64,
BLOCK_N: 128,
BLOCK_K: 32,
ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
},
Expand All @@ -1134,8 +1139,11 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
# CHECK: func @conv
# CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index

# Check we are setting gather start indices to 0
# CHECK: %{{.*}} = vector.gather %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] [%{{.*}}], %{{.*}}, %{{.*}} : memref<2x64x64x640xf16
# Input load must be contiguous.
# CHECK: %{{.*}} = vector.maskedload %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} : memref<2x64x64x640xf16

# Weights are done via gather, check we are setting gather start indices to 0.
# CHECK: %{{.*}} = vector.gather %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] [%{{.*}}], %{{.*}}, %{{.*}} : memref<3x3x640x640xf16
# CHECK: %{{.*}} = vector.gather %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] [%{{.*}}], %{{.*}}, %{{.*}} : memref<3x3x640x640xf16
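
Why the input load becomes contiguous under the updated mapping while the weight load stays a gather: the input memref is laid out N x H x W x C (2x64x64x640), and with the new `x_mapping` a unit step of the K iterator `j` advances only `j % C`, i.e. the innermost C index, as long as a per-thread group of ELEMS_PER_THREAD lanes does not cross a multiple of C (4 divides 640, so aligned groups never do). The weight buffer is permuted to HF x WF x C x NF (3x3x640x640), whose innermost dimension depends on the iterator `i` rather than `j`, so those reads remain gathers. A small spot-check with illustrative constants (W_OUT and SZ_OUT are assumed values, not taken from the test):

```python
C, WF, W_OUT, SZ_OUT, stride = 640, 3, 62, 62 * 62, 1  # illustrative values

def x_index(i, j):
    # Physical (N, H, W, C) index produced by the test's x_mapping.
    return (
        i // SZ_OUT,
        (i % SZ_OUT) % W_OUT * stride + (j // C) % WF,
        (i % SZ_OUT) // W_OUT * stride + (j // C) // WF,
        j % C,
    )

# Stepping j by one inside a lane group moves only the last (C) index, by one.
print([a - b for a, b in zip(x_index(5, 9), x_index(5, 8))])  # [0, 0, 0, 1]
```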

