Multihead attention #199

Merged
merged 73 commits into modern-fortran:main from the multihead_attention branch on Feb 21, 2025

Conversation

@OneAdder (Collaborator) commented Feb 9, 2025

Hello, Milan! I hope I'm not bothering you too much with my pull requests, but this is a good one. At this stage it is a draft of MultiHead Attention. It cannot be merged until work on input2d_layer and linear2d_layer is completed.
An implementation of dropout would also help improve MHA, but it can be added later.

MultiHead Attention

MultiHead Attention is the main component of the Transformer architecture, which is the dominant modern approach in Natural Language Processing, as well as some other areas.
Here I propose an implementation based on the Transformer paper. It works, and its output matches the state-of-the-art implementation in PyTorch.

Python Reference

https://github.com/OneAdder/neural-fortran-references/blob/main/self_attention.py
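
For readers new to the building block: each attention head computes scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V, and a multi-head layer runs several such heads in parallel and concatenates their outputs. Below is a minimal single-head sketch in plain Fortran, for illustration only; it is not the code added in this PR.

program scaled_dot_product_attention
  implicit none
  integer, parameter :: seq_len = 4, d_k = 8
  real :: q(seq_len, d_k), k(seq_len, d_k), v(seq_len, d_k)
  real :: scores(seq_len, seq_len), output(seq_len, d_k)
  integer :: i

  call random_number(q); call random_number(k); call random_number(v)

  ! raw attention scores: Q K^T scaled by sqrt(d_k)
  scores = matmul(q, transpose(k)) / sqrt(real(d_k))

  ! numerically stable row-wise softmax
  do i = 1, seq_len
    scores(i,:) = exp(scores(i,:) - maxval(scores(i,:)))
    scores(i,:) = scores(i,:) / sum(scores(i,:))
  end do

  ! each output row is an attention-weighted sum of the value rows
  output = matmul(scores, v)
  print *, output(1,:)
end program scaled_dot_product_attention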

@ricor07 (Collaborator) commented Feb 9, 2025

Hello Michael, I saw your pull requests and I think what you're doing is very interesting. Could you take a look at mine? milancurcic#2

What you should look at here is not the locally connected 1d layer but rather the reshape architecture I am trying to build. Your help would be much appreciated. Thanks.

@milancurcic (Member) commented:

Amazing, thank you! Yes, let's wrap up Input2d in #198, then Linear2d in #197, to avoid the 3d crutch.

@OneAdder force-pushed the multihead_attention branch from 48d93b2 to 86cd7c0 on February 14, 2025 17:37
@ricor07 (Collaborator) commented Feb 14, 2025

@OneAdder I see you are talking about 2d handling. Would you like to work on this together? I have to do this as well, since I'm implementing a conv1d layer.

@milancurcic (Member) commented:

Hi guys, thanks for pushing this forward. Today I'm finishing a lot of busywork with some proposals, so next week I'm getting back to neural-fortran work and will be able to contribute to the reviews more actively.

@OneAdder (Collaborator, Author) commented:

@ricor07 Great idea! But we have a problem with generics again. The issue is that predict_2d and predict_1d_batch have the same input rank. I think we should simply make a separate generic for predict_batch. @milancurcic, your thoughts?
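
For context: Fortran resolves a generic name by the type, kind, and rank of the dummy arguments, so two specifics that both accept a rank-2 real array are ambiguous under one generic. A standalone sketch of the conflict and the separate-generic fix follows; the type and procedure names here are illustrative, not the neural-fortran API.

module predict_demo
  implicit none
  type :: net
  contains
    procedure :: predict_2d
    procedure :: predict_1d_batch
    ! both specifics take a rank-2 real array, so they cannot share
    ! one generic name; a separate generic for the batch version works
    generic :: predict => predict_2d
    generic :: predict_batch => predict_1d_batch
  end type net
contains
  function predict_2d(self, x) result(y)
    class(net), intent(in) :: self
    real, intent(in) :: x(:,:)        ! one 2d sample
    real, allocatable :: y(:,:)
    y = x
  end function predict_2d
  function predict_1d_batch(self, x) result(y)
    class(net), intent(in) :: self
    real, intent(in) :: x(:,:)        ! batch of 1d samples: same rank as above
    real, allocatable :: y(:,:)
    y = x
  end function predict_1d_batch
end module predict_demo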

@ricor07 (Collaborator) commented Feb 15, 2025

Yes, I think we can make a generic predict. But I suggest you create a new branch.

@milancurcic (Member) commented:

I think it's fine to make predict_batch its own generic name, since the overlap is getting in the way. 👍

@OneAdder (Collaborator, Author) commented:

@milancurcic Done, here: #198
@ricor07 There is still one piece of the puzzle missing: a general flatten layer with interfaces for both 3D and 2D. Do you want to implement it? Or I can do it together with the linear2d layer.
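
A rank-overloaded flatten helper could look roughly like the sketch below; the names are hypothetical and only illustrate how one generic name can cover both the 2D and 3D cases, not what the eventual layer will be.

module flatten_demo
  implicit none
  ! one generic name, one specific per input rank
  interface flatten
    module procedure flatten_2d, flatten_3d
  end interface flatten
contains
  pure function flatten_2d(x) result(y)
    real, intent(in) :: x(:,:)
    real :: y(size(x))
    y = reshape(x, [size(x)])
  end function flatten_2d
  pure function flatten_3d(x) result(y)
    real, intent(in) :: x(:,:,:)
    real :: y(size(x))
    y = reshape(x, [size(x)])
  end function flatten_3d
end module flatten_demo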

@ricor07 (Collaborator) commented Feb 15, 2025

You can make it. I'll work on maxpool

@OneAdder force-pushed the multihead_attention branch 3 times, most recently from f9e7a7c to 0900990 on February 17, 2025 11:07
@OneAdder marked this pull request as ready for review on February 19, 2025 18:36
@OneAdder (Collaborator, Author) commented:

@milancurcic I think it's ready for review! I'll add the more complicated example later. At this stage I added a simple example that converges nicely and doesn't require datasets or extra dependencies.


implicit none

type, extends(multihead_attention_layer) :: cross_attention_layer
@OneAdder (Collaborator, Author) commented:

It is intentional that there is no plumbing for this one yet. I suggest we add it at a later stage, when we have more components for seq2seq models. At this stage it can be added like this, without any public access.

logical :: res

res = all(abs(x - y) <= (1e-06 + 1e-05 * abs(y)))
end function allclose
@OneAdder (Collaborator, Author) commented:

Suggestion for the future: create nf_utils.f90 (or similar) and put this procedure there.
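
Such a utility module could be as small as the sketch below; the module name follows the suggestion above and the tolerances mirror the allclose shown in the diff, but none of this is committed code.

module nf_utils
  implicit none
  private
  public :: allclose
contains
  pure logical function allclose(x, y, rtol, atol)
    ! element-wise closeness test, analogous to numpy.allclose
    real, intent(in) :: x(:), y(:)
    real, intent(in), optional :: rtol, atol
    real :: rt, at
    rt = 1e-05
    at = 1e-06
    if (present(rtol)) rt = rtol
    if (present(atol)) at = atol
    allclose = all(abs(x - y) <= (at + rt * abs(y)))
  end function allclose
end module nf_utils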

@OneAdder requested a review from milancurcic on February 19, 2025 18:47
end do
end subroutine create_attention_matrix

pure module subroutine normalize_attention_matrix(self, attention_mask)
@OneAdder (Collaborator, Author) commented:

attention_mask is not accessible to users by design at this point. It will be used by the transformer decoder, and I'll add the corresponding logic later.
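
For reference, such a mask usually enters as an additive term on the raw scores before the softmax; a causal (decoder) mask puts a large negative value on future positions so their attention weights vanish. The sketch below is a generic illustration, not this subroutine's code.

program causal_mask_demo
  implicit none
  integer, parameter :: seq_len = 4
  real :: scores(seq_len, seq_len), mask(seq_len, seq_len)
  integer :: i, j

  call random_number(scores)

  ! additive causal mask: large negative scores become ~0 after softmax
  mask = 0.
  do i = 1, seq_len
    do j = i + 1, seq_len
      mask(i, j) = -1e9
    end do
  end do
  scores = scores + mask

  do i = 1, seq_len
    print '(4es12.3)', scores(i,:)
  end do
end program causal_mask_demo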

@jvdp1 (Collaborator) left a comment:

@OneAdder Nice set of PRs. I added a few suggestions. Feel free to ignore them.
Furthermore, I usually try to avoid reshape and use pointers instead, mainly for performance reasons. I think this approach could be used in some places, but it could be left for a next PR.
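
For context, the pointer alternative mentioned here typically relies on rank-remapping pointer assignment (Fortran 2008), which reinterprets contiguous storage in a new shape without the copy that reshape makes. The sketch below is a generic illustration, not a change proposed in this PR.

program pointer_remap_demo
  implicit none
  real, target :: flat(12)
  real, pointer :: mat(:,:)
  integer :: i

  flat = [(real(i), i = 1, 12)]

  ! view the contiguous rank-1 array as a 3x4 matrix, with no copy involved
  mat(1:3, 1:4) => flat

  print *, mat(2, 3)   ! same storage in column-major order: prints 8.0
end program pointer_remap_demo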

@milancurcic (Member) commented:

@OneAdder:

self_attention_layer and cross_attention_layer both extend multihead_attention_layer, but only self_attention is exposed in the user-facing module nf and in constructors.

  1. Should cross_attention be made accessible to the user like self_attention is?
  2. Is multihead_attention intended to be used only internally as a building block, or is there value to making it part of the public API?

The answers to these questions will also guide which layers to include in the layers table in the README.

@OneAdder (Collaborator, Author) commented Feb 21, 2025

@jvdp1 Great idea! I think we should extend it even further and make something like nf_utils with a subroutine that does those in-place reshapes, so that we could use it whenever possible. But this PR is already pretty long, so I suggest we do it later.

@OneAdder (Collaborator, Author) commented:

@milancurcic

  1. Yes, but at this point we don't have any layers that would make it useful. Therefore I made it ready to be included, but not actually included yet. I suggest it makes it into the public API later.
  2. I don't know of any scenario where attention is neither self nor cross. If there is one, I think it should also be a separate class extending multihead_attention_layer, as it doesn't itself implement the forward and backward methods.

@milancurcic (Member) commented:

Thank you!!

@milancurcic merged commit ed8b340 into modern-fortran:main on Feb 21, 2025
4 checks passed