# race_dataset.py (forked from rebnej/lick-caption-bias)

import argparse
import pickle
import nltk
import numpy as np
import json
import os
import pprint
from nltk.tokenize import word_tokenize
import random
from io import open
import sys
import torch
from torch import nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm, trange


class BERT_ANN_leak_data(data.Dataset):
    """Human-annotated captions with race words masked, paired with the annotated
    skin-tone label, for training a BERT-based leakage classifier."""

    def __init__(self, d_train, d_test, args, race_task_entries, race_words, tokenizer, max_seq_length, split, caption_ind=None):
        self.task = args.task
        #self.id_2_val_obj_cap_entries = id_2_val_obj_cap_entries
        self.race_task_entries = race_task_entries
        self.cap_ind = caption_ind
        self.split = split
        self.race_words = race_words
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.d_train, self.d_test = d_train, d_test
        self.align_vocab = args.align_vocab
        if self.align_vocab:
            # Restrict human captions to the captioning model's vocabulary so human
            # and model captions are compared over the same word set.
            self.model_vocab = pickle.load(open('./bias_data/model_vocab/%s_vocab.pkl' % args.cap_model, 'rb'))
            print('len(self.model_vocab):', len(self.model_vocab))

    def __len__(self):
        if self.split == 'train':
            return len(self.d_train)
        else:
            return len(self.d_test)

    def __getitem__(self, index):
        if self.split == 'train':
            entries = self.d_train
        else:
            entries = self.d_test

        entry = entries[index]
        img_id = entry['img_id']

        # Entries are expected to carry a 'bb_skin' annotation of 'Light' or 'Dark'.
        race = entry['bb_skin']
        if race == 'Light':
            race_target = torch.tensor(0)
        elif race == 'Dark':
            race_target = torch.tensor(1)

        if self.task == 'captioning':
            ctokens = word_tokenize(entry['caption_list'][self.cap_ind].lower())
            new_list = []
            for t in ctokens:
                if t in self.race_words:
                    # Mask explicit race words so the classifier cannot read them directly.
                    new_list.append('[MASK]')
                elif self.align_vocab:
                    if t not in self.model_vocab:
                        new_list.append('[UNK]')
                    else:
                        new_list.append(t)
                else:
                    new_list.append(t)
            new_sent = ' '.join(new_list)
            encoded_dict = self.tokenizer.encode_plus(new_sent, add_special_tokens=True, truncation=True, max_length=self.max_seq_length,
                                                      padding='max_length', return_attention_mask=True, return_tensors='pt')
        elif self.task == 'vqa':
            masked_ann_concat_sent = entry['masked_ann_concat_sent']
            masked_ann_concat_sent = masked_ann_concat_sent.replace('genderword', '[MASK]')  # for BERT
            encoded_dict = self.tokenizer.encode_plus(masked_ann_concat_sent, add_special_tokens=True, truncation=True, max_length=self.max_seq_length,
                                                      padding='max_length', return_attention_mask=True, return_tensors='pt')

        input_ids = encoded_dict['input_ids']
        attention_mask = encoded_dict['attention_mask']
        token_type_ids = encoded_dict['token_type_ids']

        # encode_plus returns tensors of shape (1, max_seq_length); flatten them for batching.
        token_type_ids = token_type_ids.view(self.max_seq_length)
        input_ids = input_ids.view(self.max_seq_length)
        attention_mask = attention_mask.view(self.max_seq_length)

        return input_ids, attention_mask, token_type_ids, race_target, img_id


class BERT_MODEL_leak_data(data.Dataset):
    """Model-generated captions with race words masked, paired with the annotated
    skin-tone label, for training a BERT-based leakage classifier."""

    def __init__(self, d_train, d_test, args, race_task_entries, race_words, tokenizer, max_seq_length, split):
        self.task = args.task
        #self.id_2_val_obj_cap_entries = id_2_val_obj_cap_entries
        self.race_task_entries = race_task_entries
        self.split = split
        self.race_words = race_words
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.d_train, self.d_test = d_train, d_test

    def __len__(self):
        if self.split == 'train':
            return len(self.d_train)
        else:
            return len(self.d_test)

    def __getitem__(self, index):
        if self.split == 'train':
            entries = self.d_train
        else:
            entries = self.d_test

        entry = entries[index]
        img_id = entry['img_id']

        # Entries are expected to carry a 'bb_skin' annotation of 'Light' or 'Dark'.
        race = entry['bb_skin']
        if race == 'Light':
            race_target = torch.tensor(0)
        elif race == 'Dark':
            race_target = torch.tensor(1)

        if self.task == 'captioning':
            c_pred_tokens = word_tokenize(entry['pred'].lower())
            new_list = []
            for t in c_pred_tokens:
                if t in self.race_words:
                    # Mask explicit race words so the classifier cannot read them directly.
                    new_list.append('[MASK]')
                else:
                    new_list.append(t)
            new_sent = ' '.join(new_list)
            encoded_dict = self.tokenizer.encode_plus(new_sent, add_special_tokens=True, truncation=True, max_length=self.max_seq_length,
                                                      padding='max_length', return_attention_mask=True, return_tensors='pt')
        elif self.task == 'vqa':
            masked_model_concat_sent = entry['masked_model_concat_sent']
            masked_model_concat_sent = masked_model_concat_sent.replace('genderword', '[MASK]')  # for BERT
            encoded_dict = self.tokenizer.encode_plus(masked_model_concat_sent, add_special_tokens=True, truncation=True, max_length=self.max_seq_length,
                                                      padding='max_length', return_attention_mask=True, return_tensors='pt')

        input_ids = encoded_dict['input_ids']
        attention_mask = encoded_dict['attention_mask']
        token_type_ids = encoded_dict['token_type_ids']

        # encode_plus returns tensors of shape (1, max_seq_length); flatten them for batching.
        token_type_ids = token_type_ids.view(self.max_seq_length)
        input_ids = input_ids.view(self.max_seq_length)
        attention_mask = attention_mask.view(self.max_seq_length)

        return input_ids, attention_mask, token_type_ids, race_target, img_id
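

# Minimal usage sketch: shows how the datasets above can be wired to a BERT
# tokenizer and a DataLoader. The entry fields mirror what __getitem__ reads;
# the concrete entries, the race_words set, and the args attributes here are
# illustrative assumptions, not the project's actual data or training config.
if __name__ == '__main__':
    from transformers import BertTokenizer

    nltk.download('punkt', quiet=True)  # word_tokenize needs the punkt models

    # Synthetic entries with the fields used by the 'captioning' task.
    dummy_entries = [
        {'img_id': 1, 'bb_skin': 'Light',
         'caption_list': ['a light-skinned person riding a bike'],
         'pred': 'a person riding a bike'},
        {'img_id': 2, 'bb_skin': 'Dark',
         'caption_list': ['a dark-skinned person playing tennis'],
         'pred': 'a person playing tennis'},
    ]

    args = argparse.Namespace(task='captioning', align_vocab=False, cap_model=None)
    race_words = {'light-skinned', 'dark-skinned', 'white', 'black'}  # illustrative only
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    train_set = BERT_ANN_leak_data(dummy_entries, dummy_entries, args, dummy_entries, race_words,
                                   tokenizer, max_seq_length=64, split='train', caption_ind=0)
    loader = data.DataLoader(train_set, batch_size=2, shuffle=True)

    # Each batch yields BERT inputs, the skin-tone target, and the image id.
    input_ids, attention_mask, token_type_ids, race_target, img_id = next(iter(loader))
    print(input_ids.shape, attention_mask.shape, token_type_ids.shape, race_target, img_id)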