forked from SeanNaren/deepspeech.torch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathWEREvaluator.lua
128 lines (110 loc) · 4.51 KB
/
WEREvaluator.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
require 'Loader'
require 'Util'
require 'Mapper'
require 'torch'
require 'xlua'
require 'nnx'
require 'cutorch'
local threads = require 'threads'
local Evaluator = require 'Evaluator'
-- Torch class declaration; the methods below attach to this class table.
local WEREvaluator = torch.class('WEREvaluator')
--- Word-error-rate evaluator over the test split.
-- Holds the test-set loader, the token mapper, and a CUDA CTC criterion
-- used to report validation loss alongside WER.
function WEREvaluator:__init(_path, mapper, testBatchSize,
    logsPath, feature, dataHeight, modelname)
    self.mapper = mapper
    self.feature = feature
    self.testBatchSize = testBatchSize
    self.logsPath = logsPath
    -- Timestamp suffix keeps each run's WER log file distinct.
    self.suffix = '_' .. os.date('%Y%m%d_%H%M%S')
    self.testLoader = Loader(_path, testBatchSize, feature, dataHeight, modelname)
    self.ctc = nn.CTCCriterion():cuda()
end
--- Reshape a flat network output so the batch dimension comes first.
-- Views src as (time, batch/nGPU, classes) then swaps the first two axes,
-- giving one row per utterance. nGPU defaults to 1.
function WEREvaluator:predicTrans(src, nGPU)
    local perGpuBatch = self.testBatchSize / (nGPU or 1)
    local framed = src:view(-1, perGpuBatch, src:size(2))
    return framed:transpose(1, 2)
end
--- Run the whole test set through `model`; report average WER and CTC loss.
-- input:
--   gpu: if true, move inputs/sizes to CUDA before the forward pass
--   model: network to evaluate
--   calSizeOfSequences: maps raw frame counts to output sequence lengths
--   verbose: if true, append per-utterance WER/predictions to the log file
--   currentIteration: training iteration tag written to the log header
-- returns: { loss = <avg CTC loss per label>, WER = <avg WER in [0,1]> }
function WEREvaluator:getWER(gpu, model, calSizeOfSequences, verbose, currentIteration)
    local logName = self.logsPath .. 'WER_Test' .. self.suffix .. '.log'
    local cumWER = 0
    if verbose then
        local f = assert(io.open(logName, 'a'),
            "Could not create validation test logs, does the folder "
                .. self.logsPath .. " exist?")
        f:write('======================== BEGIN WER TEST currentIteration: '
            .. currentIteration .. ' =========================\n')
        f:close()
    end
    local werPredictions = {} -- stores the predictions to order for log.
    local loss = 0
    local N = 0
    -- ======================= for every test iteration ==========================
    for n, sample in self.testLoader:nxt_batch() do
        local sizes = calSizeOfSequences(sample.sizes)
        local targets = sample.label
        local labelcnt = sample.labelcnt
        -- Fix: take inputs from the sample unconditionally; previously it was
        -- only assigned inside the `if gpu` branch, so the CPU path forwarded nil.
        local inputs = sample.inputs
        if gpu then
            inputs = inputs:cuda()
            sizes = sizes:cuda()
        end
        local predictions = model:forward(inputs)
        if type(predictions) == 'table' then
            -- Multi-GPU output: one chunk per GPU; reshape each chunk and
            -- concatenate along the batch dimension.
            local temp = self:predicTrans(predictions[1], #predictions)
            for k = 2, #predictions do
                temp = torch.cat(temp, self:predicTrans(predictions[k], #predictions), 1)
            end
            predictions = temp
        else
            predictions = self:predicTrans(predictions)
        end
        -- =============== evaluate CTC ==================
        local predictions_ctc = predictions:transpose(1, 2)
        self.ctc:forward(predictions_ctc, targets, sizes)
        local batchLoss = self.ctc.output / labelcnt
        loss = loss + batchLoss
        -- =============== evaluate WER ==================
        local batchWER = 0
        for j = 1, self.testBatchSize do
            -- Trim padding frames before decoding this utterance.
            local prediction_single = predictions[j]:narrow(1, 1, sizes[j])
            local predict_tokens = Evaluator.predict2tokens(prediction_single, self.mapper)
            local WER = Evaluator.sequenceErrorRate(targets[j], predict_tokens)
            cumWER = cumWER + WER
            batchWER = batchWER + WER
            table.insert(werPredictions, { wer = WER * 100,
                target = self:tokens2text(targets[j]),
                prediction = self:tokens2text(predict_tokens) })
        end
        print(('Testing | Iter: %d, Error: %1.3f WER: %2.2f%%'):format(
            n, batchLoss, batchWER / self.testBatchSize * 100))
        N = N + 1
    end
    -- (Per-utterance sorting by WER was disabled in the original; predictions
    -- are logged in evaluation order.)
    if verbose then
        -- Open the log once for all per-utterance lines instead of once per line.
        local f = assert(io.open(logName, 'a'))
        for _, werPrediction in ipairs(werPredictions) do
            f:write(string.format("WER = %.2f%% | Text = \"%s\" | Predict = \"%s\"\n",
                werPrediction.wer, werPrediction.target, werPrediction.prediction))
        end
        f:close()
    end
    -- Guard against an empty test set (avoids 0/0 -> nan).
    local averageWER = 0
    if N > 0 then
        averageWER = cumWER / (N * self.testBatchSize)
        loss = loss / N
    end
    local f = assert(io.open(logName, 'a'))
    f:write(string.format("Average WER = %.2f%%\n", averageWER * 100))
    f:close()
    return { loss = loss, WER = averageWER }
end
--- Convert a sequence of token ids into its text string via the mapper.
-- Uses a buffer plus table.concat instead of repeated `..`, which is
-- O(n^2) in total string length.
function WEREvaluator:tokens2text(tokens)
    local chars = {}
    for i, token in ipairs(tokens) do
        chars[i] = self.mapper.token2alphabet[token]
    end
    return table.concat(chars)
end