Commit e597029

adds chunk streaming mask support for Transformer
1 parent 9944ec7 commit e597029

6 files changed: +116 -58 lines

espresso/models/transformer/speech_transformer_base.py (+2 -20)

@@ -126,18 +126,6 @@ def build_model(cls, cfg, task):
         else:
             transformer_encoder_input_size = task.feat_dim
 
-        encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
-            cfg.encoder.transformer_context,
-            type=int,
-        )
-        if encoder_transformer_context is not None:
-            assert len(encoder_transformer_context) == 2
-            for i in range(2):
-                assert encoder_transformer_context[i] is None or (
-                    isinstance(encoder_transformer_context[i], int)
-                    and encoder_transformer_context[i] >= 0
-                )
-
         scheduled_sampling_rate_scheduler = ScheduledSamplingRateScheduler(
             cfg.scheduled_sampling_probs,
             cfg.start_scheduled_sampling_epoch,
@@ -147,7 +135,6 @@ def build_model(cls, cfg, task):
             cfg,
             pre_encoder=conv_layers,
             input_size=transformer_encoder_input_size,
-            transformer_context=encoder_transformer_context,
         )
         decoder = cls.build_decoder(
             cfg,
@@ -162,14 +149,9 @@ def set_num_updates(self, num_updates):
         super().set_num_updates(num_updates)
 
     @classmethod
-    def build_encoder(
-        cls, cfg, pre_encoder=None, input_size=83, transformer_context=None
-    ):
+    def build_encoder(cls, cfg, pre_encoder=None, input_size=83):
         return SpeechTransformerEncoderBase(
-            cfg,
-            pre_encoder=pre_encoder,
-            input_size=input_size,
-            transformer_context=transformer_context,
+            cfg, pre_encoder=pre_encoder, input_size=input_size
         )
 
     @classmethod

espresso/models/transformer/speech_transformer_config.py (+16)

@@ -62,6 +62,22 @@ class SpeechEncoderConfig(SpeechEncDecBaseConfig):
     layer_type: LAYER_TYPE_CHOICES = field(
         default="transformer", metadata={"help": "layer type in encoder"}
     )
+    chunk_size: int = field(
+        default=0,
+        metadata={"help": "chunk size of Transformer in chunk streaming mode if > 0"},
+    )
+    chunk_left_window: int = field(
+        default=0,
+        metadata={
+            "help": "number of chunks to the left of the current chunk in chunk streaming mode"
+        },
+    )
+    chunk_right_window: int = field(
+        default=0,
+        metadata={
+            "help": "number of chunks to the right of the current chunk in chunk streaming mode"
+        },
+    )
     # config specific to Conformer
     depthwise_conv_kernel_size: int = field(
         default=31,

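The three new encoder fields jointly define the receptive field in chunk streaming mode: each frame attends to its own chunk_size-frame chunk plus chunk_left_window chunks of left context and chunk_right_window chunks of right context. A hypothetical helper (not part of the commit) that makes the arithmetic explicit:

# Hypothetical illustration (not in the commit): the maximum number of frames a
# position can attend to under the chunk streaming options, ignoring sequence
# boundaries and partial chunks.
def chunk_attention_span(chunk_size: int, left_window: int, right_window: int) -> int:
    return (1 + left_window + right_window) * chunk_size


# e.g. encoder.chunk_size=18, encoder.chunk_left_window=1, encoder.chunk_right_window=0
print(chunk_attention_span(18, 1, 0))  # at most 36 frames visible per position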
espresso/models/transformer/speech_transformer_encoder.py (+28 -4)

@@ -17,6 +17,7 @@
     RelativePositionalEmbedding,
     TransformerWithRelativePositionalEmbeddingEncoderLayerBase,
 )
+from fairseq.data import data_utils
 from fairseq.distributed import fsdp_wrap
 from fairseq.models.transformer import Linear, TransformerEncoderBase
 from fairseq.modules import (
@@ -59,7 +60,6 @@ def __init__(
         return_fc=False,
         pre_encoder=None,
         input_size=83,
-        transformer_context=None,
     ):
         self.cfg = cfg
         super(TransformerEncoderBase, self).__init__(None)  # no src dictionary
@@ -159,7 +159,19 @@ def __init__(
         else:
             self.layer_norm = None
 
-        self.transformer_context = transformer_context
+        self.transformer_context = speech_utils.eval_str_nested_list_or_tuple(
+            cfg.encoder.transformer_context,
+            type=int,
+        )
+        if self.transformer_context is not None:
+            assert len(self.transformer_context) == 2
+            for i in range(2):
+                assert self.transformer_context[i] is None or (
+                    isinstance(self.transformer_context[i], int)
+                    and self.transformer_context[i] >= 0
+                )
+
+        self.num_updates = 0
 
     def build_encoder_layer(
         self, cfg, positional_embedding: Optional[RelativePositionalEmbedding] = None
@@ -183,6 +195,10 @@ def build_encoder_layer(
             layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
         return layer
 
+    def set_num_updates(self, num_updates):
+        self.num_updates = num_updates
+        super().set_num_updates(num_updates)
+
     def output_lengths(self, in_lengths):
         return (
             in_lengths
@@ -204,6 +220,16 @@ def get_attn_mask(self, in_lengths):
         `attn_mask[tgt_i, src_j] = 1` means that when calculating the
         embedding for `tgt_i`, we exclude (mask out) `src_j`.
         """
+        if self.cfg.encoder.chunk_size > 0:
+            with data_utils.numpy_seed(self.num_updates):
+                return ~speech_utils.chunk_streaming_mask(
+                    in_lengths,
+                    self.cfg.encoder.chunk_size,
+                    left_window=self.cfg.encoder.chunk_left_window,
+                    right_window=self.cfg.encoder.chunk_right_window,
+                    always_partial_in_last=(not self.training),
+                )
+
         if self.transformer_context is None or (
             self.transformer_context[0] is None and self.transformer_context[1] is None
         ):
@@ -383,15 +409,13 @@ def __init__(
         return_fc=False,
         pre_encoder=None,
         input_size=83,
-        transformer_context=None,
     ):
         self.args = args
         super().__init__(
             SpeechTransformerConfig.from_namespace(args),
             return_fc=return_fc,
             pre_encoder=pre_encoder,
             input_size=input_size,
-            transformer_context=transformer_context,
         )
 
     def build_encoder_layer(

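Two details of the new get_attn_mask branch are worth noting: chunk_streaming_mask marks positions that may be attended to, while get_attn_mask follows the opposite convention (a True/1 entry masks a position out), hence the leading `~`; and wrapping the call in data_utils.numpy_seed(self.num_updates) makes the coin flip that decides whether the first or last chunk is partial deterministic for a given update number. A small standalone sketch of both points (the toy `visible` mask is illustrative, not produced by the helper):

import numpy as np
import torch
from fairseq.data import data_utils

# chunk_streaming_mask convention: True means "tgt may attend to src"
# (toy layout: a 2-frame chunk followed by a 1-frame chunk)
visible = torch.tensor(
    [
        [True, True, False],
        [True, True, False],
        [False, False, True],
    ]
)
# get_attn_mask convention: True means "mask out", hence the `~` in the new branch
attn_mask = ~visible
assert attn_mask[0].tolist() == [False, False, True]

# numpy_seed temporarily seeds NumPy's RNG, so the same update number
# reproduces the same partial-chunk placement inside chunk_streaming_mask.
with data_utils.numpy_seed(3):  # plays the role of self.num_updates
    flip_a = np.random.rand() > 0.5
with data_utils.numpy_seed(3):
    flip_b = np.random.rand() > 0.5
assert flip_a == flip_b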
espresso/models/transformer/speech_transformer_encoder_model.py (-17)

@@ -102,23 +102,10 @@ def build_model(cls, cfg, task):
         else:
             transformer_encoder_input_size = task.feat_dim
 
-        encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
-            cfg.encoder.transformer_context,
-            type=int,
-        )
-        if encoder_transformer_context is not None:
-            assert len(encoder_transformer_context) == 2
-            for i in range(2):
-                assert encoder_transformer_context[i] is None or (
-                    isinstance(encoder_transformer_context[i], int)
-                    and encoder_transformer_context[i] >= 0
-                )
-
         encoder = cls.build_encoder(
             cfg,
             pre_encoder=conv_layers,
             input_size=transformer_encoder_input_size,
-            transformer_context=encoder_transformer_context,
             vocab_size=(
                 len(task.target_dictionary)
                 if task.target_dictionary is not None
@@ -139,14 +126,12 @@ def build_encoder(
         cfg,
         pre_encoder=None,
         input_size=83,
-        transformer_context=None,
         vocab_size=None,
     ):
         return SpeechTransformerEncoderForPrediction(
             cfg,
             pre_encoder=pre_encoder,
             input_size=input_size,
-            transformer_context=transformer_context,
             vocab_size=vocab_size,
         )
 
@@ -174,15 +159,13 @@ def __init__(
         return_fc=False,
         pre_encoder=None,
         input_size=83,
-        transformer_context=None,
         vocab_size=None,
     ):
         super().__init__(
             cfg,
             return_fc=return_fc,
             pre_encoder=pre_encoder,
             input_size=input_size,
-            transformer_context=transformer_context,
         )
 
         self.fc_out = (

espresso/models/transformer/speech_transformer_transducer_base.py (+1 -17)

@@ -165,23 +165,10 @@ def build_model(cls, cfg, task):
         else:
             transformer_encoder_input_size = task.feat_dim
 
-        encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
-            cfg.encoder.transformer_context,
-            type=int,
-        )
-        if encoder_transformer_context is not None:
-            assert len(encoder_transformer_context) == 2
-            for i in range(2):
-                assert encoder_transformer_context[i] is None or (
-                    isinstance(encoder_transformer_context[i], int)
-                    and encoder_transformer_context[i] >= 0
-                )
-
         encoder = cls.build_encoder(
             cfg,
             pre_encoder=conv_layers,
             input_size=transformer_encoder_input_size,
-            transformer_context=encoder_transformer_context,
         )
         decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
         # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
@@ -206,14 +193,11 @@ def build_embedding(cls, cfg, dictionary, embed_dim, path=None):
         return emb
 
     @classmethod
-    def build_encoder(
-        cls, cfg, pre_encoder=None, input_size=83, transformer_context=None
-    ):
+    def build_encoder(cls, cfg, pre_encoder=None, input_size=83):
         return SpeechTransformerEncoderBase(
             cfg,
             pre_encoder=pre_encoder,
             input_size=input_size,
-            transformer_context=transformer_context,
         )
 
     @classmethod

espresso/tools/utils.py (+69)

@@ -12,6 +12,7 @@
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 
 try:
     import kaldi_io
@@ -90,6 +91,74 @@ def sequence_mask(sequence_length, max_len=None):
     return seq_range_expand < seq_length_expand
 
 
+def chunk_streaming_mask(
+    sequence_length: torch.Tensor,
+    chunk_size: int,
+    left_window: int = 0,
+    right_window: int = 0,
+    always_partial_in_last: bool = False,
+):
+    """Returns a mask for chunk streaming Transformer models.
+
+    Args:
+        sequence_length (LongTensor): sequence_length of shape `(batch)`
+        chunk_size (int): chunk size
+        left_window (int): how many left chunks can be seen (default: 0)
+        right_window (int): how many right chunks can be seen (default: 0)
+        always_partial_in_last (bool): if True always makes the last chunk partial;
+            otherwise makes either the first or last chunk have partial size randomly,
+            which is to avoid learning to emit EOS just based on partial chunk size
+            (default: False)
+
+    Returns:
+        mask (BoolTensor): a mask tensor of shape `(tgt_len, src_len)`, where
+            `tgt_len` is the length of output and `src_len` is the length of input.
+            `attn_mask[tgt_i, src_j] = True` means that when calculating the embedding
+            for `tgt_i`, we need the embedding of `src_j`.
+    """

+    max_len = sequence_length.data.max()
+    chunk_start_idx = torch.arange(
+        0,
+        max_len,
+        chunk_size,
+        dtype=sequence_length.dtype,
+        device=sequence_length.device,
+    )  # e.g. [0,18,36,54]
+    if not always_partial_in_last and np.random.rand() > 0.5:
+        # either first or last chunk is partial. If only the last one is not complete, EOS is not effective
+        chunk_start_idx = max_len - chunk_start_idx
+        chunk_start_idx = chunk_start_idx.flip([0])
+        chunk_start_idx = chunk_start_idx[:-1]
+        chunk_start_idx = F.pad(chunk_start_idx, (1, 0))
+
+    start_pad = torch.nn.functional.pad(chunk_start_idx, (1, 0))  # [0,0,18,36,54]
+    end_pad = torch.nn.functional.pad(
+        chunk_start_idx, (0, 1), value=max_len
+    )  # [0,18,36,54,max_len]
+    seq_range = torch.arange(
+        0, max_len, dtype=sequence_length.dtype, device=sequence_length.device
+    )
+    idx = (
+        (seq_range.unsqueeze(-1) >= start_pad) & (seq_range.unsqueeze(-1) < end_pad)
+    ).nonzero()[
+        :, 1
+    ]  # max_len
+    seq_range_expand = seq_range.unsqueeze(0).expand(max_len, -1)  # max_len x max_len
+
+    idx_left = idx - left_window
+    idx_left[idx_left < 0] = 0
+    boundary_left = start_pad[idx_left]  # max_len
+    mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
+
+    idx_right = idx + right_window
+    idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
+    boundary_right = end_pad[idx_right]  # max_len
+    mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
+
+    return mask_left & mask_right
+
+
 def convert_padding_direction(
     src_frames,
     src_lengths,

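A quick sanity check of the new helper on a single 6-frame utterance: with always_partial_in_last=True (the branch used at inference time) chunks are laid out from frame 0 and the mask is block-diagonal; adding left_window=1 lets each chunk also attend to the previous one. During training, always_partial_in_last=False randomly moves the partial chunk to the front, which, per the docstring, keeps the model from tying EOS to a short final chunk. A minimal usage sketch; the commented tensors are what the implementation above should produce:

import torch

from espresso.tools.utils import chunk_streaming_mask

lengths = torch.tensor([6])  # one utterance of 6 frames

# Deterministic layout: chunks start at frame 0, no cross-chunk context.
mask = chunk_streaming_mask(lengths, chunk_size=2, always_partial_in_last=True)
print(mask.int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1, 1]])

# Allow each chunk to also see one chunk to its left.
mask_l1 = chunk_streaming_mask(
    lengths, chunk_size=2, left_window=1, always_partial_in_last=True
)
print(mask_l1.int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1]])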