Skip to content

Commit 660facf

Browse files
committed
allows dictionary files w/o the counts column; rename task's
--max-num-expansions-per-step to --transducer-max-num-expansions-per-step (same as generation's) and its default is 20; prints out word counts after WER evaluation; fixes decoding log write out
1 parent e0e61e2 commit 660facf

File tree

4 files changed

+41
-11
lines changed

4 files changed

+41
-11
lines changed

espresso/speech_recognize.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def _main(cfg, output_file):
6464
datefmt="%Y-%m-%d %H:%M:%S",
6565
level=os.environ.get("LOGLEVEL", "INFO").upper(),
6666
stream=output_file,
67+
force=True,
6768
)
6869
logger = logging.getLogger("espresso.speech_recognize")
6970
if output_file is not sys.stdout: # also print to stdout
@@ -359,8 +360,8 @@ def decode_fn(x):
359360
with open(
360361
os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8"
361362
) as f:
362-
res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format(
363-
*(scorer.wer())
363+
res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}".format(
364+
*(scorer.wer()), scorer.tot_word_count()
364365
)
365366
logger.info(header + res)
366367
f.write(res + "\n")
@@ -370,8 +371,8 @@ def decode_fn(x):
370371
with open(
371372
os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8"
372373
) as f:
373-
res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format(
374-
*(scorer.cer())
374+
res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #chars={:d}".format(
375+
*(scorer.cer()), scorer.tot_char_count()
375376
)
376377
logger.info(" " * len(header) + res)
377378
f.write(res + "\n")

espresso/tasks/speech_recognition.py

+34-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import logging
99
import os
10+
import tempfile
1011
from argparse import Namespace
1112
from collections import OrderedDict
1213
from dataclasses import dataclass, field
@@ -84,11 +85,11 @@ class SpeechRecognitionEspressoConfig(FairseqDataclass):
8485
"moving EOS to the beginning of that) as input feeding"
8586
},
8687
)
87-
max_num_expansions_per_step: int = field(
88-
default=2,
88+
transducer_max_num_expansions_per_step: int = field(
89+
default=II("generation.transducer_max_num_expansions_per_step"),
8990
metadata={
9091
"help": "the maximum number of non-blank expansions in a single "
91-
"time step of decoding; only relavant when training with transducer loss"
92+
"time step of decoding in validation; only relavant when training with transducer loss"
9293
},
9394
)
9495
specaugment_config: Optional[str] = field(
@@ -340,6 +341,7 @@ def setup_task(cls, cfg: SpeechRecognitionEspressoConfig, **kwargs):
340341
"""
341342
# load dictionaries
342343
dict_path = os.path.join(cfg.data, "dict.txt") if cfg.dict is None else cfg.dict
344+
dict_path = cls._maybe_add_pseudo_counts_to_dict(dict_path)
343345
enable_blank = (
344346
True if cfg.criterion_name in ["transducer_loss", "ctc_loss"] else False
345347
)
@@ -376,13 +378,39 @@ def setup_task(cls, cfg: SpeechRecognitionEspressoConfig, **kwargs):
376378
feat_dim = src_dataset.feat_dim
377379

378380
if cfg.word_dict is not None:
379-
word_dict = cls.load_dictionary(cfg.word_dict, enable_bos=False)
381+
word_dict_path = cfg.word_dict
382+
word_dict_path = cls._maybe_add_pseudo_counts_to_dict(word_dict_path)
383+
word_dict = cls.load_dictionary(word_dict_path, enable_bos=False)
380384
logger.info("word dictionary: {} types".format(len(word_dict)))
381385
return cls(cfg, tgt_dict, feat_dim, word_dict=word_dict)
382386

383387
else:
384388
return cls(cfg, tgt_dict, feat_dim)
385389

390+
@classmethod
391+
def _maybe_add_pseudo_counts_to_dict(cls, dict_path):
392+
with open(dict_path, "r", encoding="utf-8") as f:
393+
split_list = f.readline().rstrip().rsplit(" ", 1)
394+
if len(split_list) == 2:
395+
try:
396+
int(split_list[1])
397+
return dict_path
398+
except ValueError:
399+
pass
400+
logger.info(f"No counts detected in {dict_path}. Adding pseudo counts...")
401+
with open(dict_path, "r", encoding="utf-8") as fin, tempfile.NamedTemporaryFile(
402+
"w", encoding="utf-8", delete=False
403+
) as fout:
404+
for i, line in enumerate(fin):
405+
line = line.rstrip()
406+
if len(line) == 0:
407+
logger.warning(
408+
f"Empty at line {i+1} in the dictionary {dict_path}, skipping it"
409+
)
410+
continue
411+
print(line + " 1", file=fout)
412+
return fout.name
413+
386414
def load_dataset(
387415
self,
388416
split: str,
@@ -457,7 +485,7 @@ def build_model(self, cfg: DictConfig, from_checkpoint=False):
457485
self.decoder_for_validation = TransducerGreedyDecoder(
458486
[model],
459487
self.target_dictionary,
460-
max_num_expansions_per_step=self.cfg.max_num_expansions_per_step,
488+
max_num_expansions_per_step=self.cfg.transducer_max_num_expansions_per_step,
461489
bos=(
462490
self.target_dictionary.bos()
463491
if self.cfg.include_eos_in_transducer_loss
@@ -528,7 +556,7 @@ def build_generator(
528556
beam_size=getattr(args, "beam", 1),
529557
normalize_scores=(not getattr(args, "unnormalized", False)),
530558
max_num_expansions_per_step=getattr(
531-
args, "transducer_max_num_expansions_per_step", 2
559+
args, "transducer_max_num_expansions_per_step", 20
532560
),
533561
expansion_beta=getattr(args, "transducer_expansion_beta", 0),
534562
expansion_gamma=getattr(args, "transducer_expansion_gamma", None),

examples/asr_librispeech/run_transformer_transducer.sh

+1
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ if [ ${stage} -le 8 ]; then
261261
--num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \
262262
--gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \
263263
--path $path --beam 5 --temperature 1.3 --criterion-name transducer_loss \
264+
--transducer-max-num-expansions-per-step 20 \
264265
--transducer-expansion-beta 2 --transducer-expansion-gamma 2.3 --transducer-prefix-alpha 1 \
265266
--results-path $decode_dir $opts
266267

fairseq/dataclass/configs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ class GenerationConfig(FairseqDataclass):
10511051
)
10521052
# for decoding transducer models
10531053
transducer_max_num_expansions_per_step: Optional[int] = field(
1054-
default=2,
1054+
default=20,
10551055
metadata={
10561056
"help": "the maximum number of non-blank expansions in a single time step"
10571057
},

0 commit comments

Comments
 (0)