# model_configs.py
#
# Configuration dictionaries for the ViLT and ViLT-BERT encoders and their
# classifier variants, collected in the `model_configs` registry.
from modeling.vilt import *
from modeling.viltbert import *
# Encoder keys accepted as "CL" encoders (presumably "continual learning" --
# confirm with callers). Both names are also keys of `model_configs`.
ALLOWED_CL_ENCODERS = [
    'vilt',
    'viltbert',
]
# Settings for the plain ViLT encoder: wrapper class, batch-to-input
# converter and a display name. No 'classifier_class' entry is defined here.
vilt_config = {
    'encoder_dim': 768,  # presumably the encoder's hidden size -- confirm
    'visual_input_type': 'pil-image',
    'encoder_class': ViltEncoderWrapper,
    'batch2inputs_converter': convert_batch_to_vilt_input_dict,
    'encoder_name': 'ViLT',
}
# ViLT encoder paired with the sequence-classification head; batches go
# through the sequence-specific converter. No 'encoder_name' entry here.
vilt_lang_seq_config = {
    'encoder_dim': 768,
    'visual_input_type': 'pil-image',
    'encoder_class': ViltEncoderWrapper,
    'classifier_class': ViltForSequenceClassification,
    'batch2inputs_converter': convert_seq_batch_to_vilt_input_dict,
}
# ViLT encoder paired with the multiple-choice head; batches go through the
# multiple-choice-specific converter. No 'encoder_name' entry here.
vilt_lang_mc_config = {
    'encoder_dim': 768,
    'visual_input_type': 'pil-image',
    'encoder_class': ViltEncoderWrapper,
    'classifier_class': ViltForMultipleChoice,
    'batch2inputs_converter': convert_mc_batch_to_vilt_input_dict,
}
# ViLT encoder paired with the image-classification head. Uses the same
# batch converter as the plain `vilt_config`.
vilt_vision_cls_config = {
    'encoder_dim': 768,
    'visual_input_type': 'pil-image',
    'encoder_class': ViltEncoderWrapper,
    'classifier_class': ViltForImageClassification,
    'batch2inputs_converter': convert_batch_to_vilt_input_dict,
}
# Settings for the plain ViLT-BERT encoder: wrapper class, batch-to-input
# converter and a display name. No 'classifier_class' entry is defined here.
viltbert_config = {
    'encoder_dim': 768,  # presumably the encoder's hidden size -- confirm
    'visual_input_type': 'pil-image',
    'encoder_class': ViltBertEncoderWrapper,
    'batch2inputs_converter': convert_batch_to_viltbert_input_dict,
    'encoder_name': 'ViLT-BERT',
}
# ViLT-BERT encoder paired with its sequence-classification head.
# NOTE(review): the converter is the ViLT one
# (`convert_seq_batch_to_vilt_input_dict`), not a ViLT-BERT-specific
# function -- confirm this sharing is intentional.
viltbert_lang_seq_config = {
    'encoder_dim': 768,
    'visual_input_type': 'pil-image',
    'encoder_class': ViltBertEncoderWrapper,
    'classifier_class': ViltBertForSequenceClassification,
    'batch2inputs_converter': convert_seq_batch_to_vilt_input_dict,
}
# ViLT-BERT encoder paired with its multiple-choice head.
# NOTE(review): the converter is the ViLT one
# (`convert_mc_batch_to_vilt_input_dict`), not a ViLT-BERT-specific
# function -- confirm this sharing is intentional.
viltbert_lang_mc_config = {
    'encoder_dim': 768,
    'visual_input_type': 'pil-image',
    'encoder_class': ViltBertEncoderWrapper,
    'classifier_class': ViltBertForMultipleChoice,
    'batch2inputs_converter': convert_mc_batch_to_vilt_input_dict,
}
# Registry mapping a model key to its configuration dictionary.
# Key pattern as used below: '<encoder>' for the bare encoder and
# '<encoder>-<suffix>' for a classifier variant ('v-cls', 'l-seq', 'l-mc').
model_configs = {
    # ViLT
    'vilt': vilt_config,
    'vilt-v-cls': vilt_vision_cls_config,
    'vilt-l-seq': vilt_lang_seq_config,
    'vilt-l-mc': vilt_lang_mc_config,
    # ViLT-BERT (no vision-classification variant is registered)
    'viltbert': viltbert_config,
    'viltbert-l-seq': viltbert_lang_seq_config,
    'viltbert-l-mc': viltbert_lang_mc_config,
}