Enable running softmax with TPPs #584
This IR looks similar to what we generate from [...]. The MHA softmax is a little different; we need both styles covered.
In libxsmm, softmax is lowered as an equation, and just calling the kernels one after another is very close to optimal. I would not create complicated machinery specific to certain complex patterns unless the benefit were very large and there was no other way. Softmax will eventually be lowered as an equation, which is the right approach long term, so we can live with most of the performance now and get the rest later.
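To make "calling the kernels one after another" concrete, here is a minimal sketch of softmax expressed as a chain of five simple kernels, the shape an equation-style lowering would take. NumPy ops stand in for the real TPP/libxsmm kernels; the function name and decomposition are illustrative, not the actual tpp-mlir lowering.

```python
import numpy as np

def softmax_as_kernel_sequence(x):
    """Softmax over the last axis as a sequence of simple kernels.
    Each step maps to one elementwise or reduction primitive."""
    m = np.max(x, axis=-1, keepdims=True)   # 1. reduce-max
    t = x - m                               # 2. broadcast subtract
    e = np.exp(t)                           # 3. elementwise exp
    s = np.sum(e, axis=-1, keepdims=True)   # 4. reduce-add
    return e / s                            # 5. broadcast divide

x = np.random.randn(8, 64)
y = softmax_as_kernel_sequence(x)
# every row is now a probability distribution
assert np.allclose(y.sum(axis=-1), 1.0)
```

The max-subtraction step keeps `exp` numerically stable; each of the five steps is a kernel the equation framework can dispatch back to back.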
Yes, calling the kernels one after the other would be the plan. Still, we must either fuse along the 64 and 8 dimensions to extract 2D tensors, or materialize the two outermost dimensions for each linalg op and replace the body with a TPP operation. Do you have an example of the IR generated by [...]?
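The two options above can be sketched in NumPy (shapes and helper names are assumptions for illustration, not the actual IR): either fuse the outer dimensions into the row dimension with a reshape so a single 2D kernel handles all rows, or materialize loops over the two outermost dimensions and call the 2D kernel per slice. Both produce the same result.

```python
import numpy as np

def softmax2d(x2):
    # plain 2D softmax kernel along the last axis
    e = np.exp(x2 - x2.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def fused(x):
    # Option 1: fuse the outer dims into the row dimension,
    # run one 2D kernel, then restore the original shape.
    B, H, T1, T2 = x.shape
    return softmax2d(x.reshape(B * H * T1, T2)).reshape(x.shape)

def materialized(x):
    # Option 2: materialize loops over the two outermost dims
    # and invoke the 2D kernel on each [T, T] slice.
    y = np.empty_like(x)
    for b in range(x.shape[0]):
        for h in range(x.shape[1]):
            y[b, h] = softmax2d(x[b, h])
    return y

x = np.random.randn(2, 8, 4, 4)
assert np.allclose(fused(x), materialized(x))
```

The fused form needs only one kernel launch but requires the reshape to be a no-op view; the materialized form keeps the linalg op structure intact and swaps only the body for a TPP call.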
Yup, just run [...]. Also, just to be clear, this is really low priority. Finding the right shapes for MHA and finally getting TPP on tensors into the main pipeline are still the most important tasks right now.
To enable running softmax with TPPs we need more operations:
The IR below shows a softmax example in Linalg, extracted from a self-attention layer. The lowering path is: TF dialect -> StableHLO -> Linalg IR. To lower from the TF dialect to StableHLO we use tf-opt, while from StableHLO to Linalg we use the IREE compiler and print after `iree-stablehlo-to-iree-input`. The dimensions of `arg0` are `[B, heads, T, T]`, where B is the batch dimension, heads is the number of heads, and T is the sequence length.

Related to #414.
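For concreteness, a small NumPy check of what softmax on an `arg0` of shape `[B, heads, T, T]` means: the reduction runs along the last T (the key positions), so each `[b, h, t, :]` row of attention scores becomes a probability distribution. The sizes below are made up for illustration.

```python
import numpy as np

B, heads, T = 2, 8, 16                      # made-up sizes
scores = np.random.randn(B, heads, T, T)    # arg0: [B, heads, T, T]

# softmax reduces along the last (key) dimension only
e = np.exp(scores - scores.max(axis=-1, keepdims=True))
probs = e / e.sum(axis=-1, keepdims=True)

assert probs.shape == (B, heads, T, T)
# each query row, per batch and head, sums to 1
assert np.allclose(probs.sum(axis=-1), 1.0)
```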
TBD: