forked from batra-mlp-lab/visdial
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.lua
84 lines (70 loc) · 2.74 KB
/
evaluate.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
require 'nn'
require 'rnn'
require 'nngraph'
utils = dofile('utils.lua');
-------------------------------------------------------------------------------
-- Input arguments and options
-------------------------------------------------------------------------------
-- `cmd` is intentionally left a global, exactly as before.
cmd = torch.CmdLine()
cmd:text()
cmd:text('Test the VisDial model for retrieval')
cmd:text()
cmd:text('Options')

-- Command-line flags as {flag, default, help}. They are registered in this
-- order, so the generated help text lists them in the same order.
local cliOptionSpecs = {
    -- Data input settings
    {'-inputImg', 'data/data_img.h5', 'h5file path with image feature'},
    {'-inputQues', 'data/visdial_data.h5', 'h5file file with preprocessed questions'},
    {'-inputJson', 'data/visdial_params.json', 'json path with info and vocab'},
    {'-loadPath', 'checkpoints/model.t7', 'path to saved model'},
    -- optimization params
    {'-batchSize', 200, 'Batch size (number of threads) (Adjust base on GRAM)'},
    {'-gpuid', 0, 'GPU id to use'},
    {'-backend', 'cudnn', 'nn|cudnn'},
}
for _, spec in ipairs(cliOptionSpecs) do
    cmd:option(spec[1], spec[2], spec[3])
end

local opt = cmd:parse(arg)
print(opt)

-- Fixed seed so repeated evaluation runs are reproducible.
torch.manualSeed(1234)

-- A negative -gpuid selects the CPU. Otherwise load the CUDA stack (plus
-- cudnn when requested) and select the device; the +1 maps the 0-based
-- flag onto cutorch's device index.
local defaultTensorType = 'torch.FloatTensor'
if opt.gpuid >= 0 then
    require 'cutorch'
    require 'cunn'
    if opt.backend == 'cudnn' then require 'cudnn' end
    cutorch.setDevice(opt.gpuid + 1)
    cutorch.manualSeed(1234)
    defaultTensorType = 'torch.CudaTensor'
end
torch.setdefaulttensortype(defaultTensorType)
------------------------------------------------------------------------
-- Read saved model and parameters
------------------------------------------------------------------------
local savedModel = torch.load(opt.loadPath)
-- Fail early with a readable message. A checkpoint without modelParams
-- would otherwise crash later with an obscure nil-index error.
assert(savedModel.modelParams ~= nil,
    'checkpoint ' .. tostring(opt.loadPath) .. ' has no modelParams')

-- Architecture settings (image normalisation, encoder, decoder) come from
-- the checkpoint and are mirrored into opt, which the dataloader receives.
local modelParams = savedModel.modelParams
opt.imgNorm = modelParams.imgNorm
opt.encoder = modelParams.encoder
opt.decoder = modelParams.decoder
assert(type(opt.encoder) == 'string',
    'checkpoint modelParams has no encoder name')

-- Runtime settings come from the command line. The model is constructed
-- from modelParams (not opt), so -gpuid and -batchSize must be copied here
-- or they are silently ignored.
-- NOTE(review): assumes Model / retrieve read params.batchSize — confirm
-- in model.lua.
modelParams.gpuid = opt.gpuid
modelParams.batchSize = opt.batchSize

-- Configuration flags derived from the encoder name (plain substring
-- search, no Lua patterns). Flags are only ever set to true, never false.
if opt.encoder:find('hist', 1, true) then
    opt.useHistory = true
end
if opt.encoder:find('im', 1, true) then
    opt.useIm = true
end
------------------------------------------------------------------------
-- Loading dataset
------------------------------------------------------------------------
-- dataloader.lua evaluates to the loader object. Only the 'val' split is
-- initialised, using the merged command-line/checkpoint options in opt.
local dataloader = dofile('dataloader.lua')
dataloader.initialize(dataloader, opt, {'val'})
-- Full garbage-collection cycle once loading is done (presumably to free
-- temporaries created while reading the data files).
collectgarbage('collect')
------------------------------------------------------------------------
-- Setup the model
------------------------------------------------------------------------
-- model.lua provides the Model constructor. Build the network described by
-- the checkpoint's modelParams, then copy the checkpoint's saved weights
-- (savedModel.modelW) into model.wrapperW in place.
require('model')
local model = Model(modelParams)
local checkpointWeights = savedModel.modelW
model.wrapperW:copy(checkpointWeights)
------------------------------------------------------------------------
-- Evaluation
------------------------------------------------------------------------
-- Run the model's retrieval evaluation on the split loaded above.
local evalSplit = 'val'
print('Evaluating..')
model.retrieve(model, dataloader, evalSplit)