Add evoformer example
Signed-off-by: Harsh Menon <[email protected]>
harsh-nod committed Nov 22, 2024
1 parent 59e5fe0 commit 165c51c
Showing 6 changed files with 258 additions and 25 deletions.
2 changes: 2 additions & 0 deletions iree/turbine/kernel/lang/global_symbols.py
@@ -10,6 +10,8 @@
WORKGROUP_0 = index_symbol("$WG0")
WORKGROUP_1 = index_symbol("$WG1")
WORKGROUP_2 = index_symbol("$WG2")
WORKGROUP_3 = index_symbol("$WG3")
WORKGROUP_4 = index_symbol("$WG4")

THREAD_0 = index_symbol("$T0")
THREAD_1 = index_symbol("$T1")
29 changes: 19 additions & 10 deletions iree/turbine/kernel/wave/constraints.py
@@ -324,6 +324,24 @@ class WorkgroupConstraint(Constraint):
tile_size: IndexExpr
workgroup_dim: int

def __post_init__(self):
self.wg_dim = None
match self.workgroup_dim:
case 0:
self.wg_dim = WORKGROUP_0
case 1:
self.wg_dim = WORKGROUP_1
case 2:
self.wg_dim = WORKGROUP_2
case 3:
self.wg_dim = WORKGROUP_3
case 4:
self.wg_dim = WORKGROUP_4
case _:
raise ValueError(
"Invalid workgroup dimension. Expected 0, 1, 2, 3 or 4."
)

@property
def count(self) -> IndexExpr:
"""
@@ -332,16 +350,7 @@ def count(self) -> IndexExpr:
return ceiling(self.dim / self.tile_size)

def apply(self) -> IndexSequence:
match self.workgroup_dim:
case 0:
wg_dim = WORKGROUP_0
case 1:
wg_dim = WORKGROUP_1
case 2:
wg_dim = WORKGROUP_2
case _:
raise ValueError("Invalid workgroup dimension. Expected 0, 1 or 2.")
return IndexSequence(wg_dim * self.tile_size, 1)
return IndexSequence(self.wg_dim * self.tile_size, 1)


def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]:
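With this change the constraint resolves its workgroup symbol once in __post_init__, so apply() only has to scale the cached wg_dim by the tile size, and a later pass can overwrite wg_dim with a derived expression (see initialize_workgroup_constraints in wave.py below). For illustration only, here is a standalone sketch of that behaviour with plain strings and ints standing in for the symbolic IndexExpr/IndexSequence types; ToyWorkgroupConstraint and WORKGROUP_SYMBOLS are illustrative names, not the library API.

# Standalone sketch (hypothetical stand-ins, not the iree.turbine API).
from dataclasses import dataclass

WORKGROUP_SYMBOLS = ["$WG0", "$WG1", "$WG2", "$WG3", "$WG4"]

@dataclass
class ToyWorkgroupConstraint:
    dim: int            # total size of the distributed dimension
    tile_size: int      # elements handled by each workgroup
    workgroup_dim: int  # hardware grid dimension to distribute over (0..4)

    def __post_init__(self):
        if not 0 <= self.workgroup_dim <= 4:
            raise ValueError("Invalid workgroup dimension. Expected 0, 1, 2, 3 or 4.")
        self.wg_dim = WORKGROUP_SYMBOLS[self.workgroup_dim]

    def count(self) -> int:
        # Number of workgroups needed along this dimension (ceiling division).
        return -(-self.dim // self.tile_size)

    def apply(self) -> str:
        # Start offset of this workgroup's tile, as a symbolic expression string.
        return f"{self.wg_dim} * {self.tile_size}"

c = ToyWorkgroupConstraint(dim=128, tile_size=32, workgroup_dim=3)
print(c.count(), c.apply())  # 4 $WG3 * 32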
57 changes: 45 additions & 12 deletions iree/turbine/kernel/wave/thread_shape_analysis.py
@@ -41,6 +41,9 @@ def get_custom_dim_sizes(custom: CustomOp):
def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
for target in target_dim_sizes:
if target.dim not in custom.index:
# Allow target dimensions to be missing in the source index only if they are unit dims.
if target.size == 1:
continue
raise NotImplementedError(
"NYI: Handle when source target index size is not found in target/user index."
)
@@ -68,6 +71,17 @@ def propagatable_op(node: fx.Node):
)


def propagate_resolutions(
custom_node: CustomOp, dst_op: CustomOp = None
) -> list[fx.Node]:
"""
Propagate a resolved thread shape through the graph: every propagatable op in
the forward slice of custom_node takes dst_op's index (when dst_op is given),
and the union of the forward and backward slices is returned for bookkeeping.
"""
propagated_resolutions = capture_forward_slice(custom_node.fx_node, propagatable_op)
if dst_op:
for node in propagated_resolutions:
get_custom(node).index = dst_op.index
resolved_resolutions = capture_backward_slice(custom_node.fx_node, propagatable_op)
return propagated_resolutions.union(resolved_resolutions)


def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]:
"""
This function will attempt to resolve binaryOp conflicts
@@ -81,8 +95,7 @@ def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]:
lhs_dim_set = set(lhs.type.symbolic_shape)
rhs_dim_set = set(rhs.type.symbolic_shape)
if lhs_dim_set == rhs_dim_set:
# Could be caused by consumers(likely also binaryOp) of this node.
return []
return propagate_resolutions(custom_node)
if lhs_dim_set.isdisjoint(rhs_dim_set):
raise ValueError("Cannot broadcast if lhs and rhs have disjoint shapes.")
# Determine the correct indexSize for binaryOp and insert broadcasting.
@@ -96,15 +109,7 @@ def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]:
custom_broadcast.vector_shapes = broadcast_src.vector_shapes
custom_broadcast.anchor = broadcast_src.anchor
custom_node.update_arg(broadcast_idx, custom_broadcast.fx_node)
propagated_resolutions = capture_forward_slice(
custom_broadcast.fx_node, propagatable_op
)
for node in propagated_resolutions:
get_custom(node).index = dst_op.index
resolved_resolutions = capture_backward_slice(
custom_broadcast.fx_node, propagatable_op
)
return propagated_resolutions.union(resolved_resolutions)
return propagate_resolutions(custom_broadcast, dst_op)


# Returns True iff all conflicts are handled successfully.
@@ -123,6 +128,26 @@ def handle_conflicts(conflicted_ops: set[CustomOp]):
return all_conflicts_resolved


def need_conflict_resolution(
op_to_thread_sizes: dict[CustomOp, frozenset[DimSize]],
conflicted_ops: set[CustomOp],
target_index_size: frozenset[DimSize],
) -> bool:
"""
Determine if we need to resolve conflicts. We need to resolve conflicts
only if the sizes along non-unit dims are not identical.
"""
for op in conflicted_ops:
if op not in op_to_thread_sizes:
continue
different_shapes = op_to_thread_sizes[op].symmetric_difference(
target_index_size
)
if any([dim.size != 1 for dim in different_shapes]):
return True
return False
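For illustration only (not part of the commit), here is a standalone sketch of the unit-dim rule above, with a toy DimSize namedtuple and a hypothetical differs_on_non_unit_dim helper standing in for the real types: two thread-shape sets only count as conflicting when they disagree on a dimension whose size is not 1.

# Toy stand-ins for illustration; not the iree.turbine classes.
from collections import namedtuple

DimSize = namedtuple("DimSize", ["dim", "size"])

def differs_on_non_unit_dim(a: frozenset, b: frozenset) -> bool:
    # Mirror of the check above: ignore size-1 dims when comparing thread shapes.
    return any(d.size != 1 for d in a.symmetric_difference(b))

x = frozenset({DimSize("M", 4), DimSize("N", 1)})
y = frozenset({DimSize("M", 4), DimSize("K", 1)})
z = frozenset({DimSize("M", 8), DimSize("N", 1)})

print(differs_on_non_unit_dim(x, y))  # False: only unit dims differ, no resolution needed
print(differs_on_non_unit_dim(x, z))  # True: M is 4 vs 8, so the conflict must be resolved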


###############################################################################
# Main pass
#####################################################################
@@ -236,15 +261,23 @@ def determine_thread_shapes(trace: CapturedTrace):

# Go through each index-size bucket and apply the index size to the ops in the bucket.
cummulative_set = set()
# Maintains the last thread size that was set for an op.
ops_to_thread_sizes: dict[CustomOp, frozenset[DimSize]] = {}
for target_index_size, target_ops in thread_size_to_ops.items():
# Try to handle conflicts and remove from target set if successfully handled.
if not cummulative_set.isdisjoint(target_ops):
conflicted_ops = cummulative_set.intersection(target_ops)
if handle_conflicts(conflicted_ops) == False:
if (
need_conflict_resolution(
ops_to_thread_sizes, conflicted_ops, target_index_size
)
and handle_conflicts(conflicted_ops) == False
):
raise NotImplementedError("Failed to handle conflicting thread shape.")
target_ops = target_ops.difference(conflicted_ops)
cummulative_set = cummulative_set.union(target_ops)
# Set each target op's indexSize to the size determined from the analysis.
for user in target_ops:
custom_user = get_custom(user)
set_index_size(custom_user, target_index_size)
ops_to_thread_sizes[user] = target_index_size
4 changes: 3 additions & 1 deletion iree/turbine/kernel/wave/utils.py
@@ -213,7 +213,9 @@ def is_chained_extractslice(node: fx.Node) -> bool:
get_custom(node).graph.erase_node(src_extract.fx_node)


def delinearize_index(index: IndexExpr, shape: list[int]) -> list[IndexExpr]:
def delinearize_index(
index: IndexExpr, shape: list[int | IndexExpr]
) -> list[IndexExpr]:
"""
Delinearizes a 1D index into a multi-dimensional index
based on the shapes provided. The returned array contains
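For illustration only (not part of the commit), a standalone integer-only sketch of delinearization, assuming the usual row-major convention; the real delinearize_index operates on symbolic IndexExprs and, after this change, accepts symbolic extents in the shape as well.

# Toy delinearization (row-major convention assumed for this sketch).
def delinearize(index: int, shape: list[int]) -> list[int]:
    result = []
    for extent in reversed(shape):
        result.append(index % extent)
        index //= extent
    return list(reversed(result))

# A flat index of 7 over shape [2, 4] maps to row 1, column 3 (7 == 1 * 4 + 3).
print(delinearize(7, [2, 4]))  # [1, 3]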
27 changes: 25 additions & 2 deletions iree/turbine/kernel/wave/wave.py
@@ -30,6 +30,7 @@
remove_chained_getresult,
remove_chained_extractslice,
subs_idxc,
delinearize_index,
)
from .minimize_global_loads import minimize_global_loads
from .decompose_reduce_ops import decompose_reduce_ops
@@ -207,6 +208,23 @@ def initialize_reductions(self, trace: CapturedTrace) -> None:
if tiling_constraint.dim == get_custom(reduction).axis:
reduction.count = subs_idxc(tiling_constraint.count)

def initialize_workgroup_constraints(self, trace: CapturedTrace) -> None:
"""
For kernels that distribute more than three dimensions among workgroups,
we need to update the workgroup constraints for dimensions >= 2
with the appropriate workgroup index.
"""
workgroup_dims = {x.workgroup_dim: x for x in self.workgroup_constraints}
if all(x <= 2 for x in workgroup_dims.keys()):
return
shape = [
subs_idxc(workgroup_dims[i].count)
for i in range(2, max(workgroup_dims.keys()) + 1)
]
new_workgroup_dims = delinearize_index(WORKGROUP_2, shape)
for i in range(2, max(workgroup_dims.keys()) + 1):
workgroup_dims[i].wg_dim = new_workgroup_dims[i - 2]
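For illustration only (not part of the commit), a standalone sketch of what the delinearization above computes, with plain integer workgroup counts in place of the symbolic ones and the same row-major convention assumed as in the earlier sketch: logical dimensions 2 through 4 all recover their index from the single hardware $WG2 value.

# Toy version with concrete counts; the real pass substitutes symbolic counts
# and rewrites each constraint's wg_dim to an expression in $WG2.
counts = {2: 3, 3: 4, 4: 5}               # workgroup counts for logical dims 2..4
shape = [counts[i] for i in range(2, 5)]  # [3, 4, 5]

def delinearize(index: int, shape: list[int]) -> list[int]:
    result = []
    for extent in reversed(shape):
        result.append(index % extent)
        index //= extent
    return list(reversed(result))

wg2 = 37                                  # one possible value of $WG2 in [0, 3 * 4 * 5)
print(delinearize(wg2, shape))            # [1, 3, 2] -> indices for logical dims 2, 3, 4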

def _trace_and_get_kernel_signature(
self,
args,
Expand All @@ -220,6 +238,7 @@ def _trace_and_get_kernel_signature(
self.create_induction_vars(graph)
self.initialize_wave_constraints(graph)
self.initialize_reductions(graph)
self.initialize_workgroup_constraints(graph)

idxc = IndexingContext.current()
idxc.finalize()
@@ -287,10 +306,14 @@ def _trace_and_get_kernel_signature(

# Determine grid shape.
self.grid_type.dims = [1, 1, 1]
max_workgroup_dim = 2
for constraint in self.workgroup_constraints:
self.grid_type.dims[constraint.workgroup_dim] = safe_subs(
constraint.count, idxc.subs
dim = (
constraint.workgroup_dim
if constraint.workgroup_dim < max_workgroup_dim
else max_workgroup_dim
)
self.grid_type.dims[dim] *= safe_subs(constraint.count, idxc.subs)
grid = self.grid_type

root_graph = graph.get_root_graph()
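For illustration only (not part of the commit), a standalone sketch of the grid-shape fold above with concrete counts standing in for the symbolic ones: every logical workgroup dimension beyond 2 is multiplied into the third hardware grid dimension.

# Toy grid computation; the real loop uses symbolic counts substituted
# through the indexing context.
workgroup_counts = {0: 8, 1: 4, 2: 3, 3: 4, 4: 5}  # count per logical workgroup dim

grid = [1, 1, 1]
max_workgroup_dim = 2
for wg_dim, count in workgroup_counts.items():
    dim = wg_dim if wg_dim < max_workgroup_dim else max_workgroup_dim
    grid[dim] *= count

print(grid)  # [8, 4, 60]: logical dims 2, 3 and 4 share the hardware z dimension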