model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertConfig


class SimcseModel(nn.Module):
    """Unsupervised SimCSE model."""

    def __init__(self, pretrained_model, pooling, dropout=0.3):
        super().__init__()
        # config = SimBertConfig.from_pretrained(pretrained_model)
        config = BertConfig.from_pretrained(pretrained_model)
        # Override the dropout probabilities in the config: in unsupervised
        # SimCSE, dropout is the only "augmentation", so it is raised above
        # BERT's default of 0.1.
        config.attention_probs_dropout_prob = dropout
        config.hidden_dropout_prob = dropout
        self.bert = BertModel.from_pretrained(pretrained_model, config=config)
        # self.bert = SimBertModel.from_pretrained(pretrained_model, config=config)
        self.pooling = pooling

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Flatten a possible [batch, num_sent, seq_len] input to
        # [batch * num_sent, seq_len] so paired sentences share one pass.
        seq_len = input_ids.shape[-1]
        input_ids = input_ids.view(-1, seq_len)
        attention_mask = attention_mask.view(-1, seq_len)
        token_type_ids = token_type_ids.view(-1, seq_len)
        out = self.bert(input_ids, attention_mask, token_type_ids,
                        output_hidden_states=True, return_dict=True)
        if self.pooling == 'cls':
            return out.last_hidden_state[:, 0]  # [batch, 768]
        if self.pooling == 'pooler':
            return out.pooler_output  # [batch, 768]
        if self.pooling == 'last-avg':
            last = out.last_hidden_state.transpose(1, 2)  # [batch, 768, seq_len]
            return F.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)  # [batch, 768]
        if self.pooling == 'first-last-avg':
            # hidden_states[0] is the embedding layer, so [1] is the first encoder layer.
            first = out.hidden_states[1].transpose(1, 2)  # [batch, 768, seq_len]
            last = out.hidden_states[-1].transpose(1, 2)  # [batch, 768, seq_len]
            first_avg = F.avg_pool1d(first, kernel_size=first.shape[-1]).squeeze(-1)  # [batch, 768]
            last_avg = F.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)  # [batch, 768]
            avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)  # [batch, 2, 768]
            return F.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)  # [batch, 768]
        raise ValueError(f"Unknown pooling strategy: {self.pooling}")
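

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The checkpoint name
# 'bert-base-chinese', the sample sentences, and the helper simcse_unsup_loss
# are illustrative assumptions. The loss below is the standard unsupervised
# SimCSE InfoNCE objective: each sentence is encoded twice, and its second
# encoding (under a different dropout mask) serves as the positive example.
if __name__ == '__main__':
    from transformers import BertTokenizer

    model = SimcseModel(pretrained_model='bert-base-chinese', pooling='cls')
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    # Keep train mode (the nn.Module default) so dropout stays active and the
    # two encodings of each sentence actually differ.
    model.train()

    sentences = ['今天天气很好', '今天天气不错']
    # Encode every sentence twice; the resulting order is [s1, s2, s1', s2'].
    batch = tokenizer(sentences + sentences, padding=True, return_tensors='pt')
    emb = model(batch['input_ids'], batch['attention_mask'], batch['token_type_ids'])
    print(emb.shape)  # torch.Size([4, 768])

    def simcse_unsup_loss(emb, temperature=0.05):
        """InfoNCE over cosine similarities; emb is ordered [x1..xb, x1'..xb']."""
        n = emb.shape[0]
        sim = F.cosine_similarity(emb.unsqueeze(1), emb.unsqueeze(0), dim=-1)  # [n, n]
        sim = sim - torch.eye(n, device=emb.device) * 1e12  # mask self-similarity
        labels = (torch.arange(n, device=emb.device) + n // 2) % n  # index of each positive
        return F.cross_entropy(sim / temperature, labels)

    print(simcse_unsup_loss(emb))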