
Commit 79b5f89

drop the need of cocobu_fc; use zero size tensor when use_fc or use_att is False.
ruotianluo committed Jan 29, 2020
1 parent 297e9d3 commit 79b5f89
Showing 7 changed files with 10 additions and 12 deletions.
1 change: 0 additions & 1 deletion configs/a2i2.yml
@@ -1,7 +1,6 @@
# base
caption_model: att2in2
input_json: data/cocotalk.json
-input_fc_dir: data/cocobu_fc
input_att_dir: data/cocobu_att
input_label_h5: data/cocotalk_label.h5
learning_rate: 0.0005
1 change: 0 additions & 1 deletion configs/topdown.yml
@@ -1,7 +1,6 @@
# base
caption_model: topdown
input_json: data/cocotalk.json
-input_fc_dir: data/cocobu_fc
input_att_dir: data/cocobu_att
input_label_h5: data/cocotalk_label.h5
learning_rate: 0.0005
1 change: 0 additions & 1 deletion configs/transformer.yml
@@ -4,7 +4,6 @@ noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
-input_fc_dir: data/cocobu_fc
input_att_dir: data/cocobu_att
seq_per_img: 5
batch_size: 10
5 changes: 1 addition & 4 deletions data/README.md
@@ -57,15 +57,12 @@ Then:
python script/make_bu_data.py --output_dir data/cocobu
```

-This will create `data/cocobu_fc`, `data/cocobu_att` and `data/cocobu_box`. If you want to use bottom-up feature, you can just follow the following steps and replace all cocotalk with cocobu.
+This will create `data/cocobu_fc`(not necessary), `data/cocobu_att` and `data/cocobu_box`. If you want to use bottom-up feature, you can just replace all `"cocotalk"` with `"cocobu"` in the training/test scripts.

#### Download converted files

bottomup-fc: [link](https://drive.google.com/file/d/1IpjCJ5LYC4kX2krxHcPgxAIipgA8uqTU/view?usp=sharing) (The fc features here are simply the average of the attention features)

bottomup-att: [link](https://drive.google.com/file/d/1hun0tsel34aXO4CYyTRIvHJkcbZHwjrD/view?usp=sharing)


## Flickr30k.

It's similar.
10 changes: 7 additions & 3 deletions dataloader.py
@@ -255,11 +255,15 @@ def __getitem__(self, index):
                # sort the features by the size of boxes
                att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
        else:
-           att_feat = np.zeros((1,1,1), dtype='float32')
+           att_feat = np.zeros((0,0), dtype='float32')
        if self.use_fc:
-           fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
+           try:
+               fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
+           except:
+               # Use average of attention when there is no fc provided (For bottomup feature)
+               fc_feat = att_feat.mean(0)
        else:
-           fc_feat = np.zeros((1), dtype='float32')
+           fc_feat = np.zeros((0), dtype='float32')
        if hasattr(self, 'h5_label_file'):
            seq = self.get_captions(ix, self.seq_per_img)
        else:
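For reference, the new placeholder and fallback behavior is easy to check in isolation; a minimal numpy sketch (the 36x2048 feature size is only illustrative, not taken from the diff):

```python
import numpy as np

# Illustrative bottom-up attention features: 36 regions x 2048 dims (example sizes only)
att_feat = np.random.rand(36, 2048).astype('float32')

# Fallback used when no fc feature file is available: average the attention features
fc_feat = att_feat.mean(0)
print(fc_feat.shape)  # (2048,)

# Placeholders when a feature type is disabled: zero-size arrays rather than 1-element dummies
att_placeholder = np.zeros((0, 0), dtype='float32')
fc_placeholder = np.zeros((0,), dtype='float32')
print(att_placeholder.shape, fc_placeholder.shape)  # (0, 0) (0,)
```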
2 changes: 1 addition & 1 deletion models/CaptionModel.py
@@ -235,7 +235,7 @@ def repeat_tensor(n, x):
        if x is not None:
            x = x.unsqueeze(1) # Bx1x...
            x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx...
-           x = x.reshape(-1, *x.shape[2:]) # Bnx...
+           x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx...
        return x

    @staticmethod
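This reshape change is what lets the zero-size placeholders above pass through repeat_tensor: with zero elements, an inferred -1 dimension is ambiguous and PyTorch raises an error, while the explicit B*n size is fine. A small sketch with toy shapes:

```python
import torch

n = 3
x = torch.zeros(2, 0)                            # zero-size fc placeholder (B=2)
x = x.unsqueeze(1)                               # Bx1x...
x = x.expand(-1, n, *([-1] * len(x.shape[2:])))  # Bxnx...

# x.reshape(-1, *x.shape[2:]) would raise a RuntimeError here:
# with zero elements, the -1 dimension could take any value, so it is ambiguous.
x = x.reshape(x.shape[0] * n, *x.shape[2:])      # Bnx... -- explicit size, works
print(x.shape)                                   # torch.Size([6, 0])
```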
2 changes: 1 addition & 1 deletion models/TransformerModel.py
@@ -296,7 +296,7 @@ def _prepare_feature(self, fc_feats, att_feats, att_masks):
        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
        memory = self.model.encode(att_feats, att_masks)

-       return fc_feats[...,:1], att_feats[...,:1], memory, att_masks
+       return fc_feats[...,:0], att_feats[...,:0], memory, att_masks

    def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
        att_feats, att_masks = self.clip_att(att_feats, att_masks)
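The `[..., :0]` slices play the same role on the model side: the returned dummies keep the batch dimension but hold no data, instead of the one-column dummy used before. A toy illustration (the 2048-d size is just an example):

```python
import torch

fc_feats = torch.randn(5, 2048)   # toy batch of 5 fc vectors (illustrative sizes)
print(fc_feats[..., :1].shape)    # torch.Size([5, 1])  -- old dummy: one leftover column
print(fc_feats[..., :0].shape)    # torch.Size([5, 0])  -- new dummy: empty, batch dim kept
```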
