Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement lowerings for argmin and argmax #58

Merged
merged 3 commits into from
Nov 20, 2023
Merged

Conversation

nhat-nguyen
Copy link
Collaborator

Triton argmin and argmax both lower to tt.reduce ops that have identical
semantics identical to linalg.reduce op, so we can clone tt.reduce body to
linalg.reduce directly. Unfortunately, we still need to perform pattern matching
to know what reduce ops we are dealing with so that we know how to initialize
the initial reduce values correctly.

We can do this in a generic way without pattern matching by always using
the first elements along the reduction axis and perform the reduction on
the remaining elements. However, this results in creatings sub-tensors that
aren't always multiple of 2s, which are sub-optimal for certain hardware.

@nhat-nguyen nhat-nguyen merged commit de797bb into main Nov 20, 2023
2 checks passed
@nhat-nguyen nhat-nguyen deleted the nhat/argminmax branch November 20, 2023 17:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants