Performance on other PLM #1
Hi @Hannibal046,
Hi,
Hi,

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionMerge(nn.Module):
    """Pools a sequence of hidden states into a single vector using a learned query."""

    def __init__(self, input_size):
        super().__init__()
        # Learned query vector of shape [d_model, 1].
        self.query_ = nn.Parameter(torch.empty(input_size, 1))
        nn.init.normal_(self.query_, mean=0.0, std=0.02)

    def forward(
        self,
        values,
        mask=None,  # 1 for non-pad tokens, 0 for pad tokens (common usage in Huggingface/Transformers)
    ):
        # values: [bs, length, d_model]
        # mask:   [bs, length]
        attention_scores = values @ self.query_  # [bs, length, 1]
        if mask is not None:
            # Convert the 0/1 mask into an additive bias: 0 for real tokens and
            # the most negative representable value for pad tokens.
            inverted_mask = 1.0 - mask.unsqueeze(-1).to(values.dtype)  # [bs, length, 1]
            bias = inverted_mask.masked_fill(
                inverted_mask.bool(), torch.finfo(values.dtype).min
            )
            attention_scores = attention_scores + bias
        # Without a mask, every position is assumed to be a real (non-pad) token.
        attention_probs = F.softmax(attention_scores, dim=1)  # [bs, length, 1]
        # [bs, 1, length] @ [bs, length, d_model] -> [bs, 1, d_model] -> [bs, d_model]
        return torch.bmm(attention_probs.permute(0, 2, 1), values).squeeze(1)
```
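The property the mask is meant to guarantee can be checked in a few lines: padding a sequence with arbitrary vectors should not change its pooled representation. The sketch below restates the pooling as a standalone function, `attention_pool` (a hypothetical name, not part of the snippet above), and uses `masked_fill` to set pad scores to the dtype minimum instead of adding a bias; the effect on the softmax is the same. The dimensions and random inputs are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F


def attention_pool(values, query, mask=None):
    """Hypothetical functional restatement of the masked attention pooling.

    values: [bs, length, d_model]; query: [d_model, 1];
    mask: [bs, length] with 1 for real tokens and 0 for padding.
    """
    scores = values @ query  # [bs, length, 1]
    if mask is not None:
        pad = mask.unsqueeze(-1) == 0  # True at padding positions
        scores = scores.masked_fill(pad, torch.finfo(scores.dtype).min)
    probs = F.softmax(scores, dim=1)  # normalise over the length axis
    return (probs * values).sum(dim=1)  # [bs, d_model]


torch.manual_seed(0)
d_model = 8
query = torch.randn(d_model, 1)

# One sequence of 3 real tokens, and the same sequence padded to length 5
# with large arbitrary vectors in the two pad slots.
real = torch.randn(1, 3, d_model)
padded = torch.cat([real, 100 * torch.randn(1, 2, d_model)], dim=1)
mask = torch.tensor([[1, 1, 1, 0, 0]])

pooled_real = attention_pool(real, query)
pooled_padded = attention_pool(padded, query, mask)
print(pooled_padded.shape)  # torch.Size([1, 8])
print(torch.allclose(pooled_real, pooled_padded, atol=1e-6))
```

Because the pad scores underflow to exactly zero probability after the softmax, the garbage vectors contribute nothing to the weighted sum, so both pooled vectors should agree up to floating-point rounding.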
Hello,
Amazing work! Did you ever try other PLMs (BERT, RoBERTa, ...) as your backbone model? Or did they not perform well in your preliminary experiments? Thanks so much!