|
7 | 7 | import json
|
8 | 8 | import logging
|
9 | 9 | import os
|
| 10 | +import tempfile |
10 | 11 | from argparse import Namespace
|
11 | 12 | from collections import OrderedDict
|
12 | 13 | from dataclasses import dataclass, field
|
@@ -84,11 +85,11 @@ class SpeechRecognitionEspressoConfig(FairseqDataclass):
|
84 | 85 | "moving EOS to the beginning of that) as input feeding"
|
85 | 86 | },
|
86 | 87 | )
|
87 |
| - max_num_expansions_per_step: int = field( |
88 |
| - default=2, |
| 88 | + transducer_max_num_expansions_per_step: int = field( |
| 89 | + default=II("generation.transducer_max_num_expansions_per_step"), |
89 | 90 | metadata={
|
90 | 91 | "help": "the maximum number of non-blank expansions in a single "
|
91 |
| - "time step of decoding; only relavant when training with transducer loss" |
| 92 | + "time step of decoding in validation; only relavant when training with transducer loss" |
92 | 93 | },
|
93 | 94 | )
|
94 | 95 | specaugment_config: Optional[str] = field(
|
@@ -340,6 +341,7 @@ def setup_task(cls, cfg: SpeechRecognitionEspressoConfig, **kwargs):
|
340 | 341 | """
|
341 | 342 | # load dictionaries
|
342 | 343 | dict_path = os.path.join(cfg.data, "dict.txt") if cfg.dict is None else cfg.dict
|
| 344 | + dict_path = cls._maybe_add_pseudo_counts_to_dict(dict_path) |
343 | 345 | enable_blank = (
|
344 | 346 | True if cfg.criterion_name in ["transducer_loss", "ctc_loss"] else False
|
345 | 347 | )
|
@@ -376,13 +378,39 @@ def setup_task(cls, cfg: SpeechRecognitionEspressoConfig, **kwargs):
|
376 | 378 | feat_dim = src_dataset.feat_dim
|
377 | 379 |
|
378 | 380 | if cfg.word_dict is not None:
|
379 |
| - word_dict = cls.load_dictionary(cfg.word_dict, enable_bos=False) |
| 381 | + word_dict_path = cfg.word_dict |
| 382 | + word_dict_path = cls._maybe_add_pseudo_counts_to_dict(word_dict_path) |
| 383 | + word_dict = cls.load_dictionary(word_dict_path, enable_bos=False) |
380 | 384 | logger.info("word dictionary: {} types".format(len(word_dict)))
|
381 | 385 | return cls(cfg, tgt_dict, feat_dim, word_dict=word_dict)
|
382 | 386 |
|
383 | 387 | else:
|
384 | 388 | return cls(cfg, tgt_dict, feat_dim)
|
385 | 389 |
|
| 390 | + @classmethod |
| 391 | + def _maybe_add_pseudo_counts_to_dict(cls, dict_path): |
| 392 | + with open(dict_path, "r", encoding="utf-8") as f: |
| 393 | + split_list = f.readline().rstrip().rsplit(" ", 1) |
| 394 | + if len(split_list) == 2: |
| 395 | + try: |
| 396 | + int(split_list[1]) |
| 397 | + return dict_path |
| 398 | + except ValueError: |
| 399 | + pass |
| 400 | + logger.info(f"No counts detected in {dict_path}. Adding pseudo counts...") |
| 401 | + with open(dict_path, "r", encoding="utf-8") as fin, tempfile.NamedTemporaryFile( |
| 402 | + "w", encoding="utf-8", delete=False |
| 403 | + ) as fout: |
| 404 | + for i, line in enumerate(fin): |
| 405 | + line = line.rstrip() |
| 406 | + if len(line) == 0: |
| 407 | + logger.warning( |
| 408 | + f"Empty at line {i+1} in the dictionary {dict_path}, skipping it" |
| 409 | + ) |
| 410 | + continue |
| 411 | + print(line + " 1", file=fout) |
| 412 | + return fout.name |
| 413 | + |
386 | 414 | def load_dataset(
|
387 | 415 | self,
|
388 | 416 | split: str,
|
@@ -457,7 +485,7 @@ def build_model(self, cfg: DictConfig, from_checkpoint=False):
|
457 | 485 | self.decoder_for_validation = TransducerGreedyDecoder(
|
458 | 486 | [model],
|
459 | 487 | self.target_dictionary,
|
460 |
| - max_num_expansions_per_step=self.cfg.max_num_expansions_per_step, |
| 488 | + max_num_expansions_per_step=self.cfg.transducer_max_num_expansions_per_step, |
461 | 489 | bos=(
|
462 | 490 | self.target_dictionary.bos()
|
463 | 491 | if self.cfg.include_eos_in_transducer_loss
|
@@ -528,7 +556,7 @@ def build_generator(
|
528 | 556 | beam_size=getattr(args, "beam", 1),
|
529 | 557 | normalize_scores=(not getattr(args, "unnormalized", False)),
|
530 | 558 | max_num_expansions_per_step=getattr(
|
531 |
| - args, "transducer_max_num_expansions_per_step", 2 |
| 559 | + args, "transducer_max_num_expansions_per_step", 20 |
532 | 560 | ),
|
533 | 561 | expansion_beta=getattr(args, "transducer_expansion_beta", 0),
|
534 | 562 | expansion_gamma=getattr(args, "transducer_expansion_gamma", None),
|
|
0 commit comments