
Commit 97ef591

adds an option to include EOS in transducer model training;
rewrites masked_copy_cached_state() to make it clearer and more general;
adapts code to the upstream commits of Nov 2, 2022

1 parent db4eeb2 commit 97ef591

13 files changed: +219 −75 lines

espresso/criterions/ctc_loss.py (+2 −1)

@@ -12,10 +12,11 @@
 import torch.nn.functional as F
 from omegaconf import II
 
-from fairseq import metrics, utils
+from fairseq import utils
 from fairseq.criterions import FairseqCriterion, register_criterion
 from fairseq.data import data_utils
 from fairseq.dataclass import FairseqDataclass
+from fairseq.logging import metrics
 from fairseq.tasks import FairseqTask
 
 logger = logging.getLogger(__name__)

espresso/criterions/transducer_loss.py (+12 −5)

@@ -11,10 +11,11 @@
 import torch
 from omegaconf import II
 
-from fairseq import metrics, utils
+from fairseq import utils
 from fairseq.criterions import FairseqCriterion, register_criterion
 from fairseq.data import data_utils
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
+from fairseq.logging import metrics
 from fairseq.tasks import FairseqTask
 
 logger = logging.getLogger(__name__)
@@ -36,6 +37,7 @@ class TransducerLossCriterionConfig(FairseqDataclass):
         default="torchaudio",
         metadata={"help": "choice of loss backend (native or torchaudio)"},
     )
+    include_eos: bool = II("task.include_eos_in_transducer_loss")
@@ -64,6 +66,7 @@ def __init__(self, cfg: TransducerLossCriterionConfig, task: FairseqTask):
             )
             self.rnnt_loss = rnnt_loss
 
+        self.include_eos = cfg.include_eos
         self.dictionary = task.target_dictionary
         self.prev_num_updates = -1
@@ -73,13 +76,15 @@ def forward(self, model, sample, reduce=True):
         )  # B x T x U x V, B
 
         if "target_lengths" in sample:
-            target_lengths = (
-                sample["target_lengths"].int() - 1
-            )  # Note: ensure EOS is excluded
+            target_lengths = sample["target_lengths"].int()
+            if not self.include_eos:
+                target_lengths -= 1  # excludes EOS
         else:
             target_lengths = (
                 (
                     (sample["target"] != self.pad_idx)
+                    if self.include_eos
+                    else (sample["target"] != self.pad_idx)
                     & (sample["target"] != self.eos_idx)
                 )
                 .sum(-1)
@@ -124,7 +129,9 @@ def forward(self, model, sample, reduce=True):
 
         loss = self.rnnt_loss(
             net_output,
-            sample["target"][:, :-1].int().contiguous(),  # exclude the last EOS column
+            (sample["target"] if self.include_eos else sample["target"][:, :-1])
+            .int()
+            .contiguous(),
             encoder_out_lengths.int(),
             target_lengths,
             blank=self.blank_idx,
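
The net effect of the three changed spots above: with include_eos=True the EOS column is kept in the label tensor and counted in target_lengths, while include_eos=False preserves the old behavior of excluding it. A minimal standalone sketch of that bookkeeping, with made-up token indices (pad_idx/eos_idx stand in for the criterion's attributes):

import torch

pad_idx, eos_idx = 1, 2
# two right-padded target rows, each ending in EOS before the padding
target = torch.tensor([[4, 5, 6, eos_idx], [7, 8, eos_idx, pad_idx]])

for include_eos in (False, True):
    target_lengths = (
        (target != pad_idx)
        if include_eos
        else (target != pad_idx) & (target != eos_idx)
    ).sum(-1).int()
    labels = (target if include_eos else target[:, :-1]).int().contiguous()
    print(include_eos, target_lengths.tolist())
# False -> [3, 2]: EOS excluded and the last column dropped from the labels
# True  -> [4, 3]: EOS trained on as a regular output symbol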

espresso/data/asr_dataset.py (+16 −1)

@@ -21,6 +21,7 @@ def collate(
     left_pad_source=True,
     left_pad_target=False,
     input_feeding=True,
+    maybe_bos_idx=None,
     pad_to_length=None,
     pad_to_multiple=1,
     src_bucketed=False,
@@ -89,11 +90,16 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
         prev_output_tokens = merge(
             "target",
             left_pad=left_pad_target,
-            move_eos_to_beginning=True,
+            move_eos_to_beginning=(maybe_bos_idx is None),
             pad_to_length=pad_to_length["target"]
             if pad_to_length is not None
             else None,
         )
+        if maybe_bos_idx is not None:
+            all_bos_vec = prev_output_tokens.new_full((1, 1), maybe_bos_idx).expand(
+                len(samples), 1
+            )
+            prev_output_tokens = torch.cat([all_bos_vec, prev_output_tokens], dim=1)
     else:
         ntokens = src_lengths.sum().item()
@@ -148,6 +154,10 @@ class AsrDataset(FairseqDataset):
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
+       prepend_bos_as_input_feeding (bool, optional): prepend the target with the
+           BOS symbol (instead of moving EOS to its beginning) as input feeding.
+           This is currently only used for transducer model training, where EOS is
+           retained in the target when evaluating the loss (default: False).
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        num_buckets (int, optional): if set to a value greater than 0, then
@@ -176,6 +186,7 @@ def __init__(
         left_pad_target=False,
         shuffle=True,
         input_feeding=True,
+        prepend_bos_as_input_feeding=False,
         constraints=None,
         num_buckets=0,
         src_lang_id=None,
@@ -193,6 +204,7 @@
         self.left_pad_target = left_pad_target
         self.shuffle = shuffle
         self.input_feeding = input_feeding
+        self.prepend_bos_as_input_feeding = prepend_bos_as_input_feeding
         self.constraints = constraints
         self.src_lang_id = src_lang_id
         self.tgt_lang_id = tgt_lang_id
@@ -334,6 +346,9 @@ def collater(self, samples, pad_to_length=None):
             left_pad_source=self.left_pad_source,
             left_pad_target=self.left_pad_target,
             input_feeding=self.input_feeding,
+            maybe_bos_idx=self.dictionary.bos()
+            if self.prepend_bos_as_input_feeding
+            else None,
             pad_to_length=pad_to_length,
             pad_to_multiple=self.pad_to_multiple,
             src_bucketed=(self.buckets is not None),
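
The new collate path changes only what sits at position 0 of prev_output_tokens: the default rotates EOS to the front, while prepend_bos_as_input_feeding keeps the full target (EOS included) and prepends BOS, yielding one extra column. A simplified sketch with made-up indices (the real merge() also accounts for right-padding, which is ignored here):

import torch

bos_idx, eos_idx = 0, 2
target = torch.tensor([[4, 5, 6, eos_idx]])  # one collated target row

# default input feeding: move EOS to the beginning of the row
shifted = torch.cat([target[:, -1:], target[:, :-1]], dim=1)
print(shifted.tolist())  # [[2, 4, 5, 6]]

# prepend_bos_as_input_feeding: keep EOS in place and prepend BOS
all_bos_vec = target.new_full((1, 1), bos_idx).expand(len(target), 1)
with_bos = torch.cat([all_bos_vec, target], dim=1)
print(with_bos.tolist())  # [[0, 4, 5, 6, 2]] -- one column longer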

espresso/data/feat_text_dataset.py (+1 −2)

@@ -57,7 +57,7 @@ def __init__(
     ):
         super().__init__()
         assert len(utt_ids) == len(rxfiles)
-        self.dtype = np.float
+        self.dtype = float
         self.utt_ids = utt_ids
         self.rxfiles = rxfiles
         self.size = len(utt_ids)  # number of utterances
@@ -338,7 +338,6 @@ def __init__(
         self, utt_ids: List[str], texts: List[str], dictionary=None, append_eos=True
     ):
         super().__init__()
-        self.dtype = np.float
         self.dictionary = dictionary
         self.append_eos = append_eos
         self.read_text(utt_ids, texts, dictionary)
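
For context: np.float was deprecated in NumPy 1.20 and removed in 1.24, and it was only ever an alias for the builtin float, so the replacement is behavior-preserving:

import numpy as np

# dtype=float and the old dtype=np.float both meant 64-bit floats;
# on NumPy >= 1.24 the np.float spelling raises AttributeError.
a = np.asarray([1, 2, 3], dtype=float)
print(a.dtype)  # float64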

espresso/models/speech_lstm.py (+10 −22)

@@ -1017,28 +1017,16 @@ def masked_copy_cached_state(
             src_cached_state[2],
         )
 
-        def masked_copy_state(state: Optional[Tensor], src_state: Optional[Tensor]):
-            if state is None:
-                assert src_state is None
-                return None
-            else:
-                assert (
-                    state.size(0) == mask.size(0)
-                    and src_state is not None
-                    and state.size() == src_state.size()
-                )
-                state[mask, ...] = src_state[mask, ...]
-                return state
-
-        prev_hiddens = [
-            masked_copy_state(p, src_p)
-            for (p, src_p) in zip(prev_hiddens, src_prev_hiddens)
-        ]
-        prev_cells = [
-            masked_copy_state(p, src_p)
-            for (p, src_p) in zip(prev_cells, src_prev_cells)
-        ]
-        input_feed = masked_copy_state(input_feed, src_input_feed)
+        mask = mask.unsqueeze(1)
+        prev_hiddens = speech_utils.apply_to_sample_pair(
+            lambda x, y, z=mask: torch.where(z, x, y), src_prev_hiddens, prev_hiddens
+        )
+        prev_cells = speech_utils.apply_to_sample_pair(
+            lambda x, y, z=mask: torch.where(z, x, y), src_prev_cells, prev_cells
+        )
+        input_feed = speech_utils.apply_to_sample_pair(
+            lambda x, y, z=mask: torch.where(z, x, y), src_input_feed, input_feed
+        )
 
         cached_state_new = torch.jit.annotate(
             Dict[str, Optional[Tensor]],
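
The rewritten helper trades in-place masked indexing for an out-of-place torch.where over a broadcast B x 1 mask, which is branch-free and avoids mutating the cached tensors; apply_to_sample_pair presumably just maps the lambda over matching pairs of (possibly nested) tensors. A minimal sketch of the equivalence:

import torch

mask = torch.tensor([True, False, True]).unsqueeze(1)  # B x 1, as in the commit
state = torch.zeros(3, 4)       # e.g. one layer's previous hidden state
src_state = torch.ones(3, 4)

# old style: in-place row copy guarded by asserts
old = state.clone()
old[mask.squeeze(1), ...] = src_state[mask.squeeze(1), ...]

# new style: out-of-place select, broadcasting the mask over the feature dim
new = torch.where(mask, src_state, state)

assert torch.equal(old, new)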

espresso/models/transformer/speech_transformer_decoder.py (+14 −24)

@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 from torch import Tensor
 
+import espresso.tools.utils as speech_utils
 from espresso.models.transformer import SpeechTransformerConfig
 from espresso.modules import (
     RelativePositionalEmbedding,
@@ -430,34 +431,23 @@ def masked_copy_cached_state(
                 F.pad(src_p, (0, 1)) for src_p in src_prev_key_padding_mask
             ]
 
-        def masked_copy_state(state: Optional[Tensor], src_state: Optional[Tensor]):
-            if state is None:
-                assert src_state is None
-                return None
-            else:
-                assert (
-                    state.size(0) == mask.size(0)
-                    and src_state is not None
-                    and state.size() == src_state.size()
-                )
-                state[mask, ...] = src_state[mask, ...]
-                return state
-
-        prev_key = [
-            masked_copy_state(p, src_p) for (p, src_p) in zip(prev_key, src_prev_key)
-        ]
-        prev_value = [
-            masked_copy_state(p, src_p)
-            for (p, src_p) in zip(prev_value, src_prev_value)
-        ]
+        kv_mask = mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
+        prev_key = speech_utils.apply_to_sample_pair(
+            lambda x, y, z=kv_mask: torch.where(z, x, y), src_prev_key, prev_key
+        )
+        prev_value = speech_utils.apply_to_sample_pair(
+            lambda x, y, z=kv_mask: torch.where(z, x, y), src_prev_value, prev_value
+        )
         if prev_key_padding_mask is None:
             prev_key_padding_mask = src_prev_key_padding_mask
         else:
             assert src_prev_key_padding_mask is not None
-            prev_key_padding_mask = [
-                masked_copy_state(p, src_p)
-                for (p, src_p) in zip(prev_key_padding_mask, src_prev_key_padding_mask)
-            ]
+            pad_mask = mask.unsqueeze(1)
+            prev_key_padding_mask = speech_utils.apply_to_sample_pair(
+                lambda x, y, z=pad_mask: torch.where(z, x, y),
+                src_prev_key_padding_mask,
+                prev_key_padding_mask,
+            )
 
         cached_state = torch.jit.annotate(
             Dict[str, Optional[Tensor]],
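
Here the same torch.where trick is applied to the self-attention incremental state; the three unsqueeze calls blow the per-utterance mask up to B x 1 x 1 x 1 so it broadcasts over what is, by fairseq's multi-head attention convention, a B x heads x time x head_dim cache (the key padding mask only needs B x 1). A sketch under that shape assumption:

import torch

B, H, T, D = 3, 2, 5, 8
mask = torch.tensor([True, False, True])               # which utterances to copy (B,)
kv_mask = mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)  # B x 1 x 1 x 1

prev_key = torch.zeros(B, H, T, D)     # assumed cache layout: B x heads x time x dim
src_prev_key = torch.ones(B, H, T, D)

# one broadcast select replaces per-tensor in-place indexing
merged = torch.where(kv_mask, src_prev_key, prev_key)
print(merged[:, 0, 0, 0].tolist())  # [1.0, 0.0, 1.0]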

espresso/models/transformer/speech_transformer_encoder.py (+7 −3)

@@ -9,6 +9,7 @@
 
 import torch
 import torch.nn as nn
+from torch import Tensor
 
 import espresso.tools.utils as speech_utils
 from espresso.models.transformer import SpeechTransformerConfig
@@ -314,7 +315,9 @@ def forward_scriptable(
             src_tokens,
             ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)),
         )
-        has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()
+        has_pads: Tensor = (
+            torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any()
+        )
 
         if self.fc0 is not None:
             x = self.dropout_module(x)
@@ -330,8 +333,9 @@ def forward_scriptable(
         x = self.quant_noise(x)
 
         # account for padding while computing the representation
-        if has_pads:
-            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
+        x = x * (
+            1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x)
+        )
 
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
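
Folding the `if has_pads:` branch into an unconditional multiply keeps forward_scriptable branch-free, which is friendlier to TorchScript and XLA tracing: when has_pads is false the subtracted term is zero everywhere, so the multiply is a no-op. A small check of that identity:

import torch

x = torch.randn(2, 4, 8)                      # B x T x C
encoder_padding_mask = torch.tensor(          # B x T, True at padded positions
    [[False, False, True, True], [False, False, False, False]]
)
has_pads = encoder_padding_mask.any()

# branch-free form from the commit
masked = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x))

# equivalent branchy form
expected = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) if has_pads else x
assert torch.equal(masked, expected)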
