Multihead Attention Fixes #209
Conversation
I figured out that a little refactoring would help greatly. I'll add it, and then this will be ready for review.
…e (will be needed for llama attention)
end subroutine common_forward

pure module subroutine sdpa_forward(self, attention_mask)
Put Scaled Dot Product Attention into a separate method. This adds more flexibility. In some cases we need to do manipulations with the input projections, such as KV caching for Llama and Qwen2.
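To illustrate the flexibility this enables, here is a rough sketch of a downstream caller that would rely on the split. The wrapper name forward_with_cache, the type name used here, the argument lists, and the caching step are assumptions for illustration, not code from this PR:

subroutine forward_with_cache(self, query, key, value, attention_mask)
  ! Hypothetical caller showing why scaled dot product attention is split out.
  class(multihead_attention_layer), intent(inout) :: self
  real, intent(in) :: query(:,:), key(:,:), value(:,:)
  real, intent(in) :: attention_mask(:,:)
  ! 1. compute the input projections (Q, K and V heads) as before
  call self % common_forward(query, key, value)
  ! 2. a Llama- or Qwen2-style layer could splice cached key/value heads for
  !    previously seen tokens into self % k_heads / self % v_heads here (KV caching)
  ! 3. run scaled dot product attention on the (possibly modified) projections
  call self % sdpa_forward(attention_mask)
end subroutine forward_with_cache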
Ready for review
LGTM!
contains

  procedure :: common_backward
  procedure :: common_forward
  procedure :: sdpa_forward
I was wondering what sdpa was until I found it in one of your comments below (that is, Scaled Dot Product Attention). I suggest adding a comment to explain it.
will do!
Thanks @OneAdder, this looks good. See a few comments for further optimization that we've done before in dense % forward. Worth trying if you don't mind, or let me know if I can help.
self % v_or_dv(:, :, head) = matmul(&
  transpose(self % attention_matrix(:, :, head)),&
  self % d_output(:, :, head)&
)
Since this PR is doing some optimizations, there's an opportunity here as well. The intrinsic matmul can often be slower than writing out the matrix multiplication by hand using loops, +, and *. transpose creates a temporary copy and can be avoided if we loop and assign values explicitly. When @jvdp1 made this optimization in dense % forward, I think it cut the compute time in half.
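For illustration only, a minimal sketch of the loop-based form being suggested, applied to the quoted assignment. The helper name matmul_at_b, the dummy-argument names, and the assumed slice shapes are not from this PR:

pure subroutine matmul_at_b(a, d, res)
  ! Computes res = matmul(transpose(a), d) with explicit loops,
  ! avoiding the temporary array that transpose(a) would create.
  real, intent(in) :: a(:,:)     ! e.g. attention_matrix(:, :, head), seq x seq
  real, intent(in) :: d(:,:)     ! e.g. d_output(:, :, head), seq x head_size
  real, intent(out) :: res(:,:)  ! e.g. v_or_dv(:, :, head), seq x head_size
  integer :: i, j, k
  real :: s
  do j = 1, size(d, 2)
    do i = 1, size(a, 2)
      s = 0.
      do k = 1, size(a, 1)
        s = s + a(k, i) * d(k, j)  ! transpose(a)(i, k) is a(k, i)
      end do
      res(i, j) = s
    end do
  end do
end subroutine matmul_at_b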
! calculate delta for attention matrix
d_sdpa = matmul(d_output(:, :, head), transpose(v_heads(:, :, head)))
self % d_sdpa = matmul(self % d_output(:, :, head), transpose(self % v_heads(:, :, head)))
As above, something similar could be done here.
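If it helps, a loop sketch for this particular product; the shapes and the local variables i, j, k, s are assumptions based on the snippet:

! d_sdpa = matmul(d_output(:, :, head), transpose(v_heads(:, :, head)))
! written with loops so that transpose never materializes a temporary copy
integer :: i, j, k
real :: s
do j = 1, self % sequence_length
  do i = 1, self % sequence_length
    s = 0.
    do k = 1, size(self % v_heads, 2)  ! head size
      s = s + self % d_output(i, k, head) * self % v_heads(j, k, head)
    end do
    self % d_sdpa(i, j) = s
  end do
end do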
self % d_normalize(seq, :, head) = reshape(matmul(&
  reshape(self % d_sdpa(seq, :), [1, self % sequence_length]),&
  self % jacobian * self % scaling_factor&
As above; in this case it may be even worse, since we're doing reshape(matmul(reshape(...), ...)).
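A possible loop form for this expression, as a sketch only; it assumes jacobian is a sequence_length-by-sequence_length matrix, as the reshapes imply, and uses illustrative locals i, j, s:

! row vector d_sdpa(seq, :) times (jacobian * scaling_factor), without
! the reshape temporaries; scaling_factor is applied once, after the sum
integer :: i, j
real :: s
do j = 1, self % sequence_length
  s = 0.
  do i = 1, self % sequence_length
    s = s + self % d_sdpa(seq, i) * self % jacobian(i, j)
  end do
  self % d_normalize(seq, j, head) = s * self % scaling_factor
end do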
self % q_or_dq(:, :, head) = matmul(self % d_normalize(:, :, head), self % k_heads(:, :, head))

! calculate delta for key, attention matrix should be transposed unlike for query
dk(:, :, head) = matmul(transpose(d_normalize(:, :, head)), q_heads(:, :, head))
self % k_or_dk(:, :, head) = matmul(transpose(self % d_normalize(:, :, head)), self % q_heads(:, :, head))
As above
I don't know how to do this, and I'm too lazy to learn since I have this in my editor already :)
@milancurcic @jvdp1 Do you think replacing matmuls with custom loops is worth the added complexity?
Regarding the PR, I unfortunately have to slow down for a week or two. The xylophone valley startup I worked for has folded 🙈
Thank you for all the work so far, @OneAdder, and good luck with the job search! Feel free to get back to this whenever you're ready and I'll resume with the reviews of other PRs.
Indeed, it's more complicated, as you write. Since the way you have it now is correct, any performance optimization can be left for a separate PR, and we can investigate in more detail what approach we should take. So I'll just go ahead and merge this PR and we can revisit the optimizations later.
Minor Changes to MHA