model.py
import torch
import torch.nn as nn
from transformers import AutoModel


class TourClassifier(nn.Module):
    """Early-fusion model: text and image token sequences are concatenated and
    re-encoded by a small Transformer, then classified into three label sets."""

    def __init__(self, n_classes1, n_classes2, n_classes3, text_model_name, image_model_name, device, dropout):
        super().__init__()
        self.text_model = AutoModel.from_pretrained(text_model_name).to(device)
        self.image_model = AutoModel.from_pretrained(image_model_name).to(device)
        # Trade compute for memory on both backbones.
        self.text_model.gradient_checkpointing_enable()
        self.image_model.gradient_checkpointing_enable()
        # batch_first=True so the fusion encoder attends over tokens rather than
        # over the batch dimension: forward feeds it (batch, seq, hidden) tensors.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.text_model.config.hidden_size, nhead=8, batch_first=True
        ).to(device)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2).to(device)
        self.dropout_ratio = dropout
        self.drop = nn.Dropout(p=dropout)
        self.cls = self._get_cls(n_classes1)
        self.cls2 = self._get_cls(n_classes2)
        self.cls3 = self._get_cls(n_classes3)

    def forward(self, input_ids, attention_mask, pixel_values):
        text_output = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        image_output = self.image_model(pixel_values=pixel_values)
        # Concatenate along the sequence axis; both backbones must share the
        # same hidden size for this to work.
        concat_outputs = torch.cat([text_output.last_hidden_state, image_output.last_hidden_state], 1)
        outputs = self.transformer_encoder(concat_outputs)
        # Use the first (CLS) token of the fused sequence as the pooled feature.
        outputs = outputs[:, 0]
        output = self.drop(outputs)
        out1 = self.cls(output)
        out2 = self.cls2(output)
        out3 = self.cls3(output)
        return out1, out2, out3

    def _get_cls(self, target_size):
        hidden_size = self.text_model.config.hidden_size
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.Dropout(p=self.dropout_ratio),
            nn.ReLU(),
            nn.Linear(hidden_size, target_size),
        )
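

# Hedged usage sketch for TourClassifier, kept out of the import path inside a
# helper function. The checkpoint names ("klue/roberta-base",
# "google/vit-base-patch16-224") and label counts are illustrative assumptions,
# not necessarily what this repository trained with; any text/image pair whose
# hidden sizes match (here 768) should work, since the two last_hidden_states
# are concatenated along the sequence axis.
def _example_tour_classifier():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TourClassifier(
        n_classes1=6, n_classes2=18, n_classes3=128,  # hypothetical label counts
        text_model_name="klue/roberta-base",
        image_model_name="google/vit-base-patch16-224",
        device=device,
        dropout=0.1,
    ).to(device)
    # Dummy batch: two tokenized texts of length 32 and two 224x224 RGB images.
    input_ids = torch.randint(0, 100, (2, 32), device=device)
    attention_mask = torch.ones(2, 32, dtype=torch.long, device=device)
    pixel_values = torch.randn(2, 3, 224, 224, device=device)
    out1, out2, out3 = model(input_ids, attention_mask, pixel_values)
    print(out1.shape, out2.shape, out3.shape)  # (2, 6) (2, 18) (2, 128)
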
class TourClassifier_Separate(nn.Module):
    """Late-fusion variant: each backbone gets its own classification heads,
    and the per-tower logits are blended as text + alpha * image."""

    def __init__(self, n_classes1, n_classes2, n_classes3, text_model_name, image_model_name, device, dropout, alpha):
        super().__init__()
        self.text_model = AutoModel.from_pretrained(text_model_name).to(device)
        self.image_model = AutoModel.from_pretrained(image_model_name).to(device)
        self.text_model.gradient_checkpointing_enable()
        self.image_model.gradient_checkpointing_enable()
        self.dropout_ratio = dropout
        self.drop = nn.Dropout(p=dropout)
        self.text_cls = self._get_cls(self.text_model.config.hidden_size, n_classes1)
        self.text_cls2 = self._get_cls(self.text_model.config.hidden_size, n_classes2)
        self.text_cls3 = self._get_cls(self.text_model.config.hidden_size, n_classes3)
        # ViT/SwinV2-style configs expose a single hidden_size; CNN-style
        # configs (e.g. ConvNeXt) expose per-stage hidden_sizes, so take the
        # last stage's width for the image heads.
        if 'vit' in image_model_name or 'swinv2' in image_model_name:
            image_hidden_size = self.image_model.config.hidden_size
        else:
            image_hidden_size = self.image_model.config.hidden_sizes[-1]
        self.image_cls = self._get_cls(image_hidden_size, n_classes1)
        self.image_cls2 = self._get_cls(image_hidden_size, n_classes2)
        self.image_cls3 = self._get_cls(image_hidden_size, n_classes3)
        self.alpha = alpha

    def forward(self, input_ids, attention_mask, pixel_values):
        text_output = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        image_output = self.image_model(pixel_values=pixel_values)
        text_output = self.drop(text_output.pooler_output)
        image_output = self.drop(image_output.pooler_output)
        # Flatten any trailing singleton spatial dims while keeping the batch
        # dim (a bare torch.squeeze would also drop a batch dimension of size 1).
        image_output = image_output.reshape(image_output.size(0), -1)
        text_out1 = self.text_cls(text_output)
        text_out2 = self.text_cls2(text_output)
        text_out3 = self.text_cls3(text_output)
        image_out1 = self.image_cls(image_output)
        image_out2 = self.image_cls2(image_output)
        image_out3 = self.image_cls3(image_output)
        # torch.add(a, b, alpha=c) computes a + c * b.
        out1 = torch.add(text_out1, image_out1, alpha=self.alpha)
        out2 = torch.add(text_out2, image_out2, alpha=self.alpha)
        out3 = torch.add(text_out3, image_out3, alpha=self.alpha)
        return out1, out2, out3

    def _get_cls(self, hidden_size, target_size):
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.Dropout(p=self.dropout_ratio),
            nn.ReLU(),
            nn.Linear(hidden_size, target_size),
        )
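

# Hedged usage sketch for TourClassifier_Separate. Because the towers are fused
# at the logit level (text + alpha * image), the backbones' hidden sizes need
# not match; "facebook/convnext-tiny-224" below is an illustrative assumption
# chosen to exercise the hidden_sizes[-1] branch of the constructor.
def _example_tour_classifier_separate():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TourClassifier_Separate(
        n_classes1=6, n_classes2=18, n_classes3=128,  # hypothetical label counts
        text_model_name="klue/roberta-base",
        image_model_name="facebook/convnext-tiny-224",
        device=device,
        dropout=0.1,
        alpha=0.5,  # weight given to the image-tower logits
    ).to(device)
    input_ids = torch.randint(0, 100, (2, 32), device=device)
    attention_mask = torch.ones(2, 32, dtype=torch.long, device=device)
    pixel_values = torch.randn(2, 3, 224, 224, device=device)
    out1, out2, out3 = model(input_ids, attention_mask, pixel_values)
    print(out1.shape, out2.shape, out3.shape)  # (2, 6) (2, 18) (2, 128)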