forked from noahchalifour/rnnt-speech-recognition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
63 lines (48 loc) · 2.15 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import tensorflow as tf

# warp-rnnt (HawkAaron/warp-transducer bindings) is an optional compiled
# extension providing the RNN-T loss. Evaluation still runs without it,
# but the reported loss will be zero (see do_eval.eval_step below).
_has_loss_func = False
try:
    from warprnnt_tensorflow import rnnt_loss
    _has_loss_func = True
except ImportError:
    pass
# Support both invocation styles: as a package module (relative import)
# and as a standalone script run from the project root (absolute import).
try:
    from .utils.data.common import preprocess_dataset
except ImportError:
    from utils.data.common import preprocess_dataset
def do_eval(model, dataset, batch_size,
            shuffle_buffer_size=None, distribution_strategy=None):
    """Run a single evaluation pass of an RNN-T model over a dataset.

    Args:
        model: Model exposing `vocab`, `initial_state(batch_size)`, and
            `model([features, pred_inp, enc_state], training=False)`.
        dataset: Raw dataset accepted by `preprocess_dataset`.
        batch_size: Batch size used for preprocessing and for building the
            initial encoder state.
        shuffle_buffer_size: Optional shuffle buffer size passed through to
            `preprocess_dataset`.
        distribution_strategy: Optional `tf.distribute.Strategy`. When given,
            the dataset is distributed and each step runs via the strategy.

    Returns:
        Tuple `(eval_loss, eval_accuracy)` of scalar tensors: the mean loss
        over all batches and the sparse categorical accuracy.
    """
    _dataset = preprocess_dataset(dataset, model.vocab, batch_size,
        shuffle_buffer_size=shuffle_buffer_size)

    if distribution_strategy is not None:
        _dataset = distribution_strategy.experimental_distribute_dataset(
            _dataset)

    eval_loss = tf.keras.metrics.Mean(name='eval_loss')
    eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='eval_accuracy')

    # NOTE(review): the signature assumes 80-dim filterbank features and a
    # [2, batch, units] encoder state — confirm against the model config.
    @tf.function(input_signature=[tf.TensorSpec([None, None, 80], tf.float32),
                                  tf.TensorSpec([None, None], tf.int32),
                                  tf.TensorSpec([None], tf.int32),
                                  tf.TensorSpec([None], tf.int32),
                                  tf.TensorSpec([2, None, None], tf.float32)])
    def eval_step(fb, labels, fb_lengths, labels_lengths, enc_state):
        # Teacher forcing: the prediction network consumes tokens [0, T-1]
        # and the targets are tokens [1, T].
        pred_inp = labels[:, :-1]
        pred_out = labels[:, 1:]

        predictions, _ = model([fb, pred_inp, enc_state],
            training=False)

        if _has_loss_func:
            # BUG FIX: the module imports `rnnt_loss` directly
            # (`from warprnnt_tensorflow import rnnt_loss`); the original
            # referenced the unbound name `warprnnt_tensorflow`, raising
            # NameError whenever the loss library was actually installed.
            loss = rnnt_loss(predictions,
                             pred_out,
                             fb_lengths,
                             labels_lengths)
        else:
            # warp-rnnt unavailable: report zero loss; accuracy still updates.
            loss = 0

        eval_loss(loss)
        eval_accuracy(pred_out, predictions)

    enc_state = model.initial_state(batch_size)

    for (inp, tar, inp_length, tar_length) in _dataset:
        if distribution_strategy is not None:
            # NOTE(review): `experimental_run_v2` is removed in newer TF
            # (replaced by `Strategy.run`); kept here for compatibility with
            # the TF version this project pins — confirm before upgrading.
            distribution_strategy.experimental_run_v2(
                eval_step, args=(inp, tar, inp_length, tar_length, enc_state))
        else:
            eval_step(inp, tar, inp_length, tar_length, enc_state)

    return eval_loss.result(), eval_accuracy.result()