Embedding Layer #205
Conversation
@jvdp1 @milancurcic Ready for review
LGTM. Below are some minor suggestions.
@milancurcic Your opinion?
Thanks for the nudge. Will try to finish review and merge either today or Monday. Thank you for all the hard work!
@OneAdder Can you add the Embedding layer entry to the table of layers in the README? I assume that based on it being provided in
@milancurcic Readme updated. Thank you for resolving conflicts! |
Thank you!
Input Embeddings Lookup Table (Trainable)
Core
In Natural Language Processing, input data is often encoded as indices of tokens in a vocabulary. Those indices are converted into vectors using trainable weights.
In theory, similar behaviour can be achieved by expanding the input data into vectors of the desired size and putting them through `input2d` and `linear2d` layers. However, that is very inefficient, as we would have to do a `matmul` each time instead of simply selecting an element by index. So, I created a layer that does this efficiently. It does not have a gradient, as it is intended as an input layer, but it is trainable (`get_params`, `get_gradients` and `set_params`).
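To illustrate the point, here is a small NumPy sketch (illustrative only, not the neural-fortran API; the variable names are hypothetical) showing that a one-hot matmul and a plain row lookup give the same result, while the lookup avoids the matrix product entirely:

```python
import numpy as np

vocab_size, model_dim = 10, 4
rng = np.random.default_rng(0)
weights = rng.normal(size=(vocab_size, model_dim))  # hypothetical trainable lookup table
indices = np.array([3, 1, 7])                       # token indices of a short sequence

# Inefficient route: one-hot encode the indices, then matmul
# (roughly what chaining input2d + linear2d would amount to).
one_hot = np.eye(vocab_size)[indices]
via_matmul = one_hot @ weights

# Efficient route: plain row lookup, which is what the embedding layer does.
via_lookup = weights[indices]

assert np.allclose(via_matmul, via_lookup)
```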
Positional
Apart from this core functionality, I also added positional encoding (the output vectors get sin/cos waves added to them according to their positions). This will be needed for transformers.
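For reference, a rough sketch of sinusoidal positional encoding, assuming the standard "Attention Is All You Need" formulation (the exact formulation used in this PR may differ; names and shapes are illustrative):

```python
import numpy as np

def positional_encoding(sequence_length: int, model_dim: int) -> np.ndarray:
    """Sinusoidal positional encoding; sin on even dimensions, cos on odd ones."""
    positions = np.arange(sequence_length)[:, None]   # (seq_len, 1)
    dims = np.arange(model_dim)[None, :]              # (1, model_dim)
    angle_rates = 1.0 / np.power(10000.0, (2 * (dims // 2)) / model_dim)
    angles = positions * angle_rates                  # (seq_len, model_dim)
    encoding = np.zeros((sequence_length, model_dim))
    encoding[:, 0::2] = np.sin(angles[:, 0::2])       # even dimensions: sine wave
    encoding[:, 1::2] = np.cos(angles[:, 1::2])       # odd dimensions: cosine wave
    return encoding

# The embedding output (seq_len x model_dim) simply has this added to it.
```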
Python Reference
Here: `torch.nn.Embedding` and a custom function for positional encoding.
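A hedged sketch of what that reference amounts to (illustrative only, not the actual reference script used for this PR):

```python
import torch

vocab_size, model_dim = 10, 8
embedding = torch.nn.Embedding(vocab_size, model_dim)   # trainable lookup table
tokens = torch.tensor([[3, 1, 7, 0, 2]])                # (batch, seq_len) of token indices
vectors = embedding(tokens)                             # (batch, seq_len, model_dim)
# Positional encoding is not part of torch.nn.Embedding; it is added separately,
# e.g. with a sinusoidal function like the sketch above.
```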