-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TKW] Propagate dim index in thread shape analysis #288
base: main
Are you sure you want to change the base?
Changes from all commits
a74ba32
a19acf0
69b9d6a
eabe087
4e1dbcd
2579cec
2ca85f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,39 +19,48 @@ | |
|
||
|
||
@dataclass(order=True) | ||
class DimSize: | ||
class DimIndex: | ||
dim: IndexSymbol | ||
size: int | ||
seq: IndexSequence | ||
|
||
@property | ||
def size(self) -> IndexExpr: | ||
return self.seq.size | ||
|
||
def __hash__(self): | ||
return hash((self.dim, self.size)) | ||
return hash((self.dim, self.seq)) | ||
|
||
|
||
def get_dim_sizes(indices: list[IndexSequence]): | ||
dims = frozenset( | ||
[DimSize(dim, subs_idxc(seq.size)) for dim, seq in indices.items()] | ||
) | ||
def process_seq(seq): | ||
return subs_idxc(seq) | ||
|
||
|
||
def get_dim_indices(indices: list[IndexSequence]): | ||
dims = frozenset([DimIndex(dim, process_seq(seq)) for dim, seq in indices.items()]) | ||
return dims | ||
|
||
|
||
def get_custom_dim_sizes(custom: CustomOp): | ||
return get_dim_sizes(custom.index) | ||
def get_custom_dim_indices(custom: CustomOp): | ||
return get_dim_indices(custom.index) | ||
|
||
|
||
def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]): | ||
def set_custom_index(custom: CustomOp, target_dim_sizes: list[DimIndex]): | ||
for target in target_dim_sizes: | ||
if target.dim not in custom.index: | ||
raise NotImplementedError( | ||
"NYI: Handle when source target index size is not found in target/user index." | ||
) | ||
custom.index[target.dim].size = target.size | ||
custom.index[target.dim] = target.seq | ||
|
||
|
||
################################################################# | ||
# Anchor Indices and Conflict resolution helpers | ||
################################################################# | ||
|
||
anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape) | ||
# TODO: Permute ops can have different indices on input and output. | ||
# Add it to the anchorOpTypes to stop index propagation during forward/backward | ||
# lookups. | ||
anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape, Permute) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add the forward propagation of permute, just as a safety measure to ensure we won't be generating "valid" but incorrect IRs? Speaking from experience, it would be much better for the program to crash than to have to debug why the MLIR is wrong and where the error is coming from. 😄 |
||
noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you also add a comment explaining why Permute is added as an anchor op? |
||
legalSubtypes = (IterArg,) | ||
nonPropagatableTypes = anchorOpTypes + noHandleTypes | ||
|
@@ -145,7 +154,7 @@ def determine_thread_shapes(trace: CapturedTrace): | |
thread_shape in its indexSequence. | ||
|
||
`thread_shapes` is used to store thread_size at every dimension that the op | ||
cares about. We use a frozenset[DimSize] to represent it, where DimSize | ||
cares about. We use a frozenset[DimIndex] to represent it, where DimIndex | ||
is essentially a pair<dimension: IndexSymbol, thread_size: int>. We are using a | ||
frozenset since we do not care about the order of dims for the shape/size | ||
propagation. | ||
|
@@ -171,15 +180,17 @@ def determine_thread_shapes(trace: CapturedTrace): | |
|
||
""" | ||
anchor_ops = trace.walk(is_anchor_op) | ||
thread_size_to_ops: dict[frozenset[DimSize], set[CustomOp]] = {} | ||
thread_size_to_ops: dict[frozenset[DimIndex], set[CustomOp]] = {} | ||
|
||
def update_dims(index: frozenset[DimIndex], ops: set[CustomOp]): | ||
thread_size_to_ops[index] = thread_size_to_ops.get(index, set([])).union(ops) | ||
|
||
for anchor_op in anchor_ops: | ||
custom = get_custom(anchor_op) | ||
index_sizes = get_custom_dim_sizes(custom) | ||
index_sizes = get_custom_dim_indices(custom) | ||
if isinstance(custom, Read): | ||
fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op) | ||
thread_size_to_ops[index_sizes] = thread_size_to_ops.get( | ||
index_sizes, set([]) | ||
).union(fwd_slice) | ||
update_dims(index_sizes, fwd_slice) | ||
elif isinstance(custom, ReduceOp): | ||
fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op) | ||
bwd_slice = set() | ||
|
@@ -188,18 +199,18 @@ def determine_thread_shapes(trace: CapturedTrace): | |
): | ||
bwd_slice = capture_backward_slice(custom.init, propagatable_op) | ||
reduce_dims = frozenset( | ||
[DimSize(dim, 1) for dim in custom.index.keys() if dim != custom.dim] | ||
) | ||
thread_size_to_ops[reduce_dims] = ( | ||
thread_size_to_ops.get(reduce_dims, set([])) | ||
.union(fwd_slice) | ||
.union(bwd_slice) | ||
[ | ||
DimIndex(dim, process_seq(IndexSequence(seq.start, 1, 1))) | ||
for dim, seq in custom.index.items() | ||
if dim != custom.dim | ||
] | ||
) | ||
|
||
update_dims(reduce_dims, fwd_slice) | ||
update_dims(reduce_dims, bwd_slice) | ||
elif isinstance(custom, Write): | ||
bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op) | ||
thread_size_to_ops[index_sizes] = thread_size_to_ops.get( | ||
index_sizes, set([]) | ||
).union(bwd_slice) | ||
update_dims(index_sizes, bwd_slice) | ||
elif isinstance(custom, MMA): | ||
lhs_bwd_slice = set([custom.lhs]) | ||
if propagatable_op(custom.lhs): | ||
|
@@ -212,27 +223,19 @@ def determine_thread_shapes(trace: CapturedTrace): | |
acc_slice = acc_slice.union( | ||
capture_backward_slice(custom.acc, propagatable_op) | ||
) | ||
acc_index = get_dim_sizes(custom.acc_index) | ||
lhs_index = get_dim_sizes(custom.lhs_index) | ||
rhs_index = get_dim_sizes(custom.rhs_index) | ||
thread_size_to_ops[acc_index] = thread_size_to_ops.get( | ||
acc_index, set([]) | ||
).union(acc_slice) | ||
thread_size_to_ops[lhs_index] = thread_size_to_ops.get( | ||
lhs_index, set([]) | ||
).union(lhs_bwd_slice) | ||
thread_size_to_ops[rhs_index] = thread_size_to_ops.get( | ||
rhs_index, set([]) | ||
).union(rhs_bwd_slice) | ||
acc_index = get_dim_indices(custom.acc_index) | ||
lhs_index = get_dim_indices(custom.lhs_index) | ||
rhs_index = get_dim_indices(custom.rhs_index) | ||
update_dims(acc_index, acc_slice) | ||
update_dims(lhs_index, lhs_bwd_slice) | ||
update_dims(rhs_index, rhs_bwd_slice) | ||
elif isinstance(custom, Reshape): | ||
# The reshape op acts like a barrier for the MMA preventing | ||
# the mma from propagating the thread shapes of its reshaped | ||
# operands backwards. | ||
bwd_size = get_dim_sizes(custom.args.index) | ||
bwd_size = get_dim_indices(custom.args.index) | ||
bwd_slice = capture_backward_slice(custom.args, propagatable_op) | ||
thread_size_to_ops[bwd_size] = thread_size_to_ops.get( | ||
bwd_size, set([]) | ||
).union(bwd_slice) | ||
update_dims(bwd_size, bwd_slice) | ||
|
||
# Go through each index-size bucket, and apply the index-size to ops in the bucket. | ||
cummulative_set = set() | ||
|
@@ -241,10 +244,17 @@ def determine_thread_shapes(trace: CapturedTrace): | |
if not cummulative_set.isdisjoint(target_ops): | ||
conflicted_ops = cummulative_set.intersection(target_ops) | ||
if handle_conflicts(conflicted_ops) == False: | ||
raise NotImplementedError("Failed to handle conflicting thread shape.") | ||
offenders = tuple( | ||
(ops, dim) | ||
for dim, ops in thread_size_to_ops.items() | ||
if not conflicted_ops.isdisjoint(ops) | ||
) | ||
raise NotImplementedError( | ||
f"Failed to handle conflicting thread shape: {conflicted_ops}, {offenders}" | ||
) | ||
target_ops = target_ops.difference(conflicted_ops) | ||
cummulative_set = cummulative_set.union(target_ops) | ||
# Set the target ops' indexSize to the one determined from analysis. | ||
for user in target_ops: | ||
custom_user = get_custom(user) | ||
set_index_size(custom_user, target_index_size) | ||
set_custom_index(custom_user, target_index_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment explaining why this is necessary?