@@ -21,6 +21,7 @@ def collate(
21
21
left_pad_source = True ,
22
22
left_pad_target = False ,
23
23
input_feeding = True ,
24
+ maybe_bos_idx = None ,
24
25
pad_to_length = None ,
25
26
pad_to_multiple = 1 ,
26
27
src_bucketed = False ,
@@ -89,11 +90,16 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
89
90
prev_output_tokens = merge (
90
91
"target" ,
91
92
left_pad = left_pad_target ,
92
- move_eos_to_beginning = True ,
93
+ move_eos_to_beginning = ( maybe_bos_idx is None ) ,
93
94
pad_to_length = pad_to_length ["target" ]
94
95
if pad_to_length is not None
95
96
else None ,
96
97
)
98
+ if maybe_bos_idx is not None :
99
+ all_bos_vec = prev_output_tokens .new_full ((1 , 1 ), maybe_bos_idx ).expand (
100
+ len (samples ), 1
101
+ )
102
+ prev_output_tokens = torch .cat ([all_bos_vec , prev_output_tokens ], dim = 1 )
97
103
else :
98
104
ntokens = src_lengths .sum ().item ()
99
105
@@ -148,6 +154,10 @@ class AsrDataset(FairseqDataset):
148
154
(default: True).
149
155
input_feeding (bool, optional): create a shifted version of the targets
150
156
to be passed into the model for teacher forcing (default: True).
157
+ prepend_bos_as_input_feeding (bool, optional): build the input-feeding
158
+ tokens by prepending a BOS symbol to the target (instead of moving EOS to
159
+ the beginning of the target). This is currently only for a transducer model
160
+ training setting where EOS is retained in target when evaluating the loss (default: False).
151
161
constraints (Tensor, optional): 2d tensor with a concatenated, zero-
152
162
delimited list of constraints for each sentence.
153
163
num_buckets (int, optional): if set to a value greater than 0, then
@@ -176,6 +186,7 @@ def __init__(
176
186
left_pad_target = False ,
177
187
shuffle = True ,
178
188
input_feeding = True ,
189
+ prepend_bos_as_input_feeding = False ,
179
190
constraints = None ,
180
191
num_buckets = 0 ,
181
192
src_lang_id = None ,
@@ -193,6 +204,7 @@ def __init__(
193
204
self .left_pad_target = left_pad_target
194
205
self .shuffle = shuffle
195
206
self .input_feeding = input_feeding
207
+ self .prepend_bos_as_input_feeding = prepend_bos_as_input_feeding
196
208
self .constraints = constraints
197
209
self .src_lang_id = src_lang_id
198
210
self .tgt_lang_id = tgt_lang_id
@@ -334,6 +346,9 @@ def collater(self, samples, pad_to_length=None):
334
346
left_pad_source = self .left_pad_source ,
335
347
left_pad_target = self .left_pad_target ,
336
348
input_feeding = self .input_feeding ,
349
+ maybe_bos_idx = self .dictionary .bos ()
350
+ if self .prepend_bos_as_input_feeding
351
+ else None ,
337
352
pad_to_length = pad_to_length ,
338
353
pad_to_multiple = self .pad_to_multiple ,
339
354
src_bucketed = (self .buckets is not None ),
0 commit comments