[TKW] Propagate dim index in thread shape analysis #288

Open · wants to merge 7 commits into base: main
Conversation

Hardcode84 (Contributor)
Refactor thread_shape_analysis to take into account entire index instead of just elements per thread count.
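For context, a minimal Python sketch of the idea described above: carry the full per-dimension index through the analysis instead of collapsing it to a single elements_per_thread count. All names here (DimIndex, propagate_thread_shapes) are hypothetical, not the actual TKW implementation.

from dataclasses import dataclass

@dataclass(frozen=True)
class DimIndex:
    """Per-dimension index info carried through the analysis (hypothetical)."""
    dim: str     # symbolic dimension name, e.g. "M"
    size: int    # elements owned by one thread along this dim
    stride: int  # stride between consecutive elements

def propagate_thread_shapes(anchor_index: dict[str, DimIndex],
                            user_indices: list[dict[str, DimIndex]]) -> dict[str, DimIndex]:
    """Merge users' per-dim indices with the anchor's, keeping the whole
    DimIndex rather than only an element count."""
    resolved = dict(anchor_index)
    for user in user_indices:
        for dim, idx in user.items():
            if dim not in resolved:
                resolved[dim] = idx
            elif resolved[dim] != idx:
                # With the full index we can detect stride/start mismatches,
                # not just differing element counts.
                raise ValueError(f"conflicting index for dim {dim}: {resolved[dim]} vs {idx}")
    return resolved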

@Hardcode84 Hardcode84 changed the title [WIP] Propagate dim index [WIP] Propagate dim index in thread shape analysis Nov 21, 2024
@Hardcode84 Hardcode84 marked this pull request as ready for review November 22, 2024 00:08
@Hardcode84 Hardcode84 changed the title [WIP] Propagate dim index in thread shape analysis [TKW] Propagate dim index in thread shape analysis Nov 22, 2024
@harsh-nod (Contributor) left a comment
LGTM! Just a few requests for comments. Also, could you check whether we still need the propagation of elements_per_thread in index_sequence_analysis? Thanks!

@@ -1349,6 +1349,10 @@ def num_reduction_dims(self) -> int:
def reduction_dim(self) -> IndexSymbol:
return self.dim

@property
Can you add a comment explaining why this is necessary?

@@ -51,7 +57,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
# Anchor Indicies and Conflict resolution helpers
#################################################################

anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape)
anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape, Permute)
noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
Also, could you add a comment explaining why Permute is added as an anchor op?

@@ -51,7 +57,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
# Anchor Indicies and Conflict resolution helpers
#################################################################

anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape)
anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape, Permute)
Can we add the forward propagation of permute, just as a safety measure to ensure we won't be generating "valid" but incorrect IR? Speaking from experience, it is much better for the program to crash than to debug why the MLIR is wrong and where the error is coming from. 😄
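A rough illustration of the kind of forward propagation being asked for, under assumed names (propagate_through_permute, DimIndex from the sketch above); this is not the actual TKW code, just a sketch of the fail-fast behavior:

def propagate_through_permute(index: dict[str, "DimIndex"],
                              target_order: list[str]) -> dict[str, "DimIndex"]:
    """Remap per-dim index entries to the permuted dim order, failing loudly
    instead of silently producing an index that leads to incorrect IR."""
    missing = [d for d in target_order if d not in index]
    if missing:
        # Crash early rather than emit "valid" but wrong MLIR downstream.
        raise RuntimeError(f"permute target dims missing from index: {missing}")
    return {dim: index[dim] for dim in target_order}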
