diff --git a/iree/turbine/kernel/_support/indexing.py b/iree/turbine/kernel/_support/indexing.py
index b99d7b5b..326927ed 100644
--- a/iree/turbine/kernel/_support/indexing.py
+++ b/iree/turbine/kernel/_support/indexing.py
@@ -430,3 +430,6 @@ def __repr__(self):
         if isinstance(self.size, int) and self.size <= 1:
             return f"{self.start}"
         return f"{self.start} : {self.size} : {self.stride}"
+
+    def __hash__(self):
+        return hash((self.start, self.size, self.stride))
diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index b1c25440..8848b691 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -1349,6 +1349,14 @@ def num_reduction_dims(self) -> int:
     def reduction_dim(self) -> IndexSymbol:
         return self.dim
 
+    # In `set_node_indices` there is logic that propagates `elements_per_thread`
+    # from previous ops if it wasn't set for the current op, which causes ReduceOp
+    # to get wrong indices. This property prevents that propagation.
+    # TODO: remove after index handling is fully switched to thread_shape_analysis.
+    @property
+    def elements_per_thread(self) -> int:
+        return 1
+
 
 # TODO: Add support for more shuffle types.
 @define_op("shuffle")
diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py
index 129f7551..49caa640 100644
--- a/iree/turbine/kernel/wave/thread_shape_analysis.py
+++ b/iree/turbine/kernel/wave/thread_shape_analysis.py
@@ -19,39 +19,48 @@
 @dataclass(order=True)
-class DimSize:
+class DimIndex:
     dim: IndexSymbol
-    size: int
+    seq: IndexSequence
+
+    @property
+    def size(self) -> IndexExpr:
+        return self.seq.size
 
     def __hash__(self):
-        return hash((self.dim, self.size))
+        return hash((self.dim, self.seq))
 
 
-def get_dim_sizes(indices: list[IndexSequence]):
-    dims = frozenset(
-        [DimSize(dim, subs_idxc(seq.size)) for dim, seq in indices.items()]
-    )
+def process_seq(seq):
+    return subs_idxc(seq)
+
+
+def get_dim_indices(indices: list[IndexSequence]):
+    dims = frozenset([DimIndex(dim, process_seq(seq)) for dim, seq in indices.items()])
     return dims
 
 
-def get_custom_dim_sizes(custom: CustomOp):
-    return get_dim_sizes(custom.index)
+def get_custom_dim_indices(custom: CustomOp):
+    return get_dim_indices(custom.index)
 
 
-def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
+def set_custom_index(custom: CustomOp, target_dim_sizes: list[DimIndex]):
     for target in target_dim_sizes:
         if target.dim not in custom.index:
             raise NotImplementedError(
                 "NYI: Handle when source target index size is not found in target/user index."
             )
-        custom.index[target.dim].size = target.size
+        custom.index[target.dim] = target.seq
 
 
 #################################################################
 # Anchor Indicies and Conflict resolution helpers
 #################################################################
 
-anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape)
+# TODO: Permute ops can have different indices on input and output.
+# Add it to the anchorOpTypes to stop index propagation during forward/backward
+# lookups.
+anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape, Permute)
 noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
 legalSubtypes = (IterArg,)
 nonPropagatableTypes = anchorOpTypes + noHandleTypes
@@ -145,7 +154,7 @@ def determine_thread_shapes(trace: CapturedTrace):
     thread_shape in it's indexSequence.
 
     `thread_shapes` is used to store thread_size at every dimension that the op
-    cares about. We use a frozenset[DimSize] to represent it, where DimSize
+    cares about. We use a frozenset[DimIndex] to represent it, where DimIndex
     is essentially a pair. we are using frozen_set since we do not care about
     the order of dims for the shape/size propagation.
@@ -171,15 +180,17 @@
     """
     anchor_ops = trace.walk(is_anchor_op)
-    thread_size_to_ops: dict[frozenset[DimSize], set[CustomOp]] = {}
+    thread_size_to_ops: dict[frozenset[DimIndex], set[CustomOp]] = {}
+
+    def update_dims(index: frozenset[DimIndex], ops: set[CustomOp]):
+        thread_size_to_ops[index] = thread_size_to_ops.get(index, set([])).union(ops)
+
     for anchor_op in anchor_ops:
         custom = get_custom(anchor_op)
-        index_sizes = get_custom_dim_sizes(custom)
+        index_sizes = get_custom_dim_indices(custom)
         if isinstance(custom, Read):
             fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
-            thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
-                index_sizes, set([])
-            ).union(fwd_slice)
+            update_dims(index_sizes, fwd_slice)
         elif isinstance(custom, ReduceOp):
             fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
             bwd_slice = set()
@@ -188,18 +199,18 @@
             ):
                 bwd_slice = capture_backward_slice(custom.init, propagatable_op)
             reduce_dims = frozenset(
-                [DimSize(dim, 1) for dim in custom.index.keys() if dim != custom.dim]
-            )
-            thread_size_to_ops[reduce_dims] = (
-                thread_size_to_ops.get(reduce_dims, set([]))
-                .union(fwd_slice)
-                .union(bwd_slice)
+                [
+                    DimIndex(dim, process_seq(IndexSequence(seq.start, 1, 1)))
+                    for dim, seq in custom.index.items()
+                    if dim != custom.dim
+                ]
             )
+
+            update_dims(reduce_dims, fwd_slice)
+            update_dims(reduce_dims, bwd_slice)
         elif isinstance(custom, Write):
             bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op)
-            thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
-                index_sizes, set([])
-            ).union(bwd_slice)
+            update_dims(index_sizes, bwd_slice)
         elif isinstance(custom, MMA):
             lhs_bwd_slice = set([custom.lhs])
             if propagatable_op(custom.lhs):
@@ -212,27 +223,19 @@
                 acc_slice = acc_slice.union(
                     capture_backward_slice(custom.acc, propagatable_op)
                 )
-            acc_index = get_dim_sizes(custom.acc_index)
-            lhs_index = get_dim_sizes(custom.lhs_index)
-            rhs_index = get_dim_sizes(custom.rhs_index)
-            thread_size_to_ops[acc_index] = thread_size_to_ops.get(
-                acc_index, set([])
-            ).union(acc_slice)
-            thread_size_to_ops[lhs_index] = thread_size_to_ops.get(
-                lhs_index, set([])
-            ).union(lhs_bwd_slice)
-            thread_size_to_ops[rhs_index] = thread_size_to_ops.get(
-                rhs_index, set([])
-            ).union(rhs_bwd_slice)
+            acc_index = get_dim_indices(custom.acc_index)
+            lhs_index = get_dim_indices(custom.lhs_index)
+            rhs_index = get_dim_indices(custom.rhs_index)
+            update_dims(acc_index, acc_slice)
+            update_dims(lhs_index, lhs_bwd_slice)
+            update_dims(rhs_index, rhs_bwd_slice)
         elif isinstance(custom, Reshape):
             # The reshape op acts like a barrier for the MMA preventing
             # the mma from propagating the thread shapes of its reshaped
             # operands backwards.
-            bwd_size = get_dim_sizes(custom.args.index)
+            bwd_size = get_dim_indices(custom.args.index)
             bwd_slice = capture_backward_slice(custom.args, propagatable_op)
-            thread_size_to_ops[bwd_size] = thread_size_to_ops.get(
-                bwd_size, set([])
-            ).union(bwd_slice)
+            update_dims(bwd_size, bwd_slice)
 
     # Go through each index-size buckets, and apply the index-size to ops in the bucket.
     cummulative_set = set()
@@ -241,10 +244,17 @@
         if not cummulative_set.isdisjoint(target_ops):
             conflicted_ops = cummulative_set.intersection(target_ops)
             if handle_conflicts(conflicted_ops) == False:
-                raise NotImplementedError("Failed to handle conflicting thread shape.")
+                offenders = tuple(
+                    (ops, dim)
+                    for dim, ops in thread_size_to_ops.items()
+                    if not conflicted_ops.isdisjoint(ops)
+                )
+                raise NotImplementedError(
+                    f"Failed to handle conflicting thread shape: {conflicted_ops}, {offenders}"
+                )
             target_ops = target_ops.difference(conflicted_ops)
         cummulative_set = cummulative_set.union(target_ops)
         # Set target ops's indexSize to be the determined from analysis.
         for user in target_ops:
             custom_user = get_custom(user)
-            set_index_size(custom_user, target_index_size)
+            set_custom_index(custom_user, target_index_size)
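
A minimal standalone sketch (outside the patch) of the bucket-union pattern that `update_dims` introduces above, and of why `IndexSequence` now needs `__hash__`: whole index sequences, not plain integer sizes, become part of the frozenset keys. `Seq` is a hypothetical stand-in for `IndexSequence`, and the op handles are plain strings; only the shape of the logic mirrors the patch.

from dataclasses import dataclass


@dataclass(frozen=True)
class Seq:
    # Hypothetical stand-in for IndexSequence: (start, size, stride).
    start: int
    size: int
    stride: int


@dataclass(frozen=True)
class DimIndex:
    # Mirrors the DimIndex dataclass added above: a (dim, index sequence) pair.
    dim: str
    seq: Seq


thread_size_to_ops: dict[frozenset, set] = {}


def update_dims(index: frozenset, ops: set) -> None:
    # Union the ops into the bucket keyed by the frozenset of DimIndex entries,
    # the same accumulate-into-bucket pattern used in determine_thread_shapes.
    thread_size_to_ops[index] = thread_size_to_ops.get(index, set()).union(ops)


key = frozenset({DimIndex("M", Seq(0, 4, 1)), DimIndex("N", Seq(0, 1, 1))})
update_dims(key, {"read_0", "add_1"})
update_dims(key, {"write_2"})
assert thread_size_to_ops[key] == {"read_0", "add_1", "write_2"}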