-
Notifications
You must be signed in to change notification settings - Fork 330
/
Copy pathfairseq_layers.py
611 lines (539 loc) · 23.4 KB
/
fairseq_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
"""
Copyright 2021 The LightSeq Team
Copyright Facebook Fairseq
We use layers from Facebook Fairseq as the baseline for our unit tests.
"""
from typing import Dict, List, Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
class FSTransformerEncoderLayer(nn.Module):
    """Encoder layer implemented by fairseq.

    This version only removes the "args" parameter, no other changes.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`.
    In the tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. Note that this class defaults to
    ``normalize_before=True``, i.e. the tensor2tensor (pre-norm) approach;
    pass ``normalize_before=False`` to get the paper's (post-norm) ordering.

    Args:
        embed_dim: input/output feature size of the layer.
        ffn_embed_dim: output size of fc1 / input size of fc2.
        nhead: number of attention heads (passed to MultiheadAttention).
        dropout: probability of the dropout applied to the attention output
            and to the fc2 output.
        attn_dropout: dropout probability passed to MultiheadAttention.
        activation_dropout: probability of the dropout applied after the
            FFN activation.
        normalize_before: apply layer norm before (True) or after (False)
            each sublayer's residual block.
        activation_fn: activation name resolved by
            ``utils.get_activation_fn``.
        quant_noise: quant-noise probability for the linear projections.
        quant_noise_block_size: block size used by quant noise.
    """

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        nhead,
        dropout,
        attn_dropout,
        activation_dropout,
        normalize_before=True,
        activation_fn="relu",
        quant_noise=0,
        quant_noise_block_size=8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Inside __init__ the `quant_noise` argument shadows the imported
        # `quant_noise` function. build_fc1/build_fc2 are separate methods,
        # so they still call the module-level function.
        # These two attributes must be set before build_self_attention,
        # which reads them.
        self.quant_noise = quant_noise
        self.quant_noise_block_size = quant_noise_block_size
        self.self_attn = self.build_self_attention(self.embed_dim, nhead, attn_dropout)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        # A single dropout module is reused after attention and after fc2.
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.activation_fn = utils.get_activation_fn(activation=activation_fn)
        activation_dropout_p = activation_dropout
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = normalize_before
        self.fc1 = self.build_fc1(
            self.embed_dim,
            ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the first FFN projection (input_dim -> output_dim) with quant noise."""
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the second FFN projection (input_dim -> output_dim) with quant noise."""
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_self_attention(self, embed_dim, nhead, attn_dropout):
        """Build the self-attention module; reads self.quant_noise settings."""
        return MultiheadAttention(
            embed_dim,
            nhead,
            dropout=attn_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def residual_connection(self, x, residual):
        """Return ``residual + x`` (plain additive residual)."""
        return residual + x

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`

        The rename is done in place on ``state_dict``; keys under ``name``
        that use the old naming are moved to the new names and deleted.
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        # --- self-attention sublayer ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- feed-forward sublayer ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x
class FSTransformerDecoderLayer(nn.Module):
    """Decoder layer implemented by fairseq.

    This version only removes the "args" parameter, no other changes.

    A call runs three sublayers in order: self-attention, encoder-decoder
    attention (only if it was built and ``encoder_out`` is given), and the
    feed-forward network. Each sublayer is followed by dropout and a residual
    add. Layer norm is applied before the sublayer when
    ``normalize_before=True`` (the default), otherwise after the residual.

    Args:
        embed_dim: decoder feature size.
        ffn_embed_dim: output size of fc1 / input size of fc2.
        nhead: number of attention heads (used for both attention modules).
        encoder_embed_dim: feature size of ``encoder_out``; used as
            kdim/vdim of the encoder-decoder attention.
        dropout: probability of the dropout applied after each sublayer.
        attn_dropout: dropout probability passed to both attention modules.
        activation_dropout: probability of the dropout applied after the
            FFN activation.
        normalize_before: pre-norm (True) or post-norm (False).
        activation_fn: activation name resolved by
            ``utils.get_activation_fn``.
        quant_noise: quant-noise probability for the linear projections.
        quant_noise_block_size: block size used by quant noise.
        cross_self_attention: if True, the self-attention module is built
            with ``self_attention=False``. Its keys/values in ``forward`` are
            then ``cat(encoder_out, x)`` instead of ``x``.
        no_encoder_attn: if True, ``encodec_attn`` and its layer norm are
            set to None and the encoder-decoder sublayer is skipped.
        add_bias_kv: forwarded to the self-attention module.
        add_zero_attn: forwarded to the self-attention module.
    """

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        nhead,
        encoder_embed_dim,
        dropout,
        attn_dropout,
        activation_dropout,
        normalize_before=True,
        activation_fn="relu",
        quant_noise=0,
        quant_noise_block_size=8,
        cross_self_attention=False,
        no_encoder_attn=False,
        add_bias_kv=False,
        add_zero_attn=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # A single dropout module is reused after every sublayer.
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        # Must be set before the build_* calls below, which read them.
        self.quant_noise = quant_noise
        self.quant_noise_block_size = quant_noise_block_size
        self.cross_self_attention = cross_self_attention
        self.self_attn = self.build_self_attention(
            self.embed_dim,
            nhead,
            attn_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.activation_fn = utils.get_activation_fn(activation=activation_fn)
        activation_dropout_p = activation_dropout
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = normalize_before
        # Always False here; passed through to every LayerNorm below.
        export = False
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
        # The encoder-decoder attention lives under the `encodec_attn*` names;
        # get_fairseq_dec_params reads these attribute names.
        if no_encoder_attn:
            self.encodec_attn = None
            self.encodec_attn_layer_norm = None
        else:
            self.encodec_attn = self.build_encoder_attention(
                self.embed_dim, encoder_embed_dim, attn_dropout, nhead
            )
            self.encodec_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.fc1 = self.build_fc1(
            self.embed_dim,
            ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        # need_attn: in eval mode, request encoder-attention weights even if
        # the caller passes need_attn=False (see forward).
        self.need_attn = True
        # onnx_trace: when True, forward also returns the self-attention state.
        self.onnx_trace = False

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the first FFN projection (input_dim -> output_dim) with quant noise."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the second FFN projection (input_dim -> output_dim) with quant noise."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, nhead, attn_dropout, add_bias_kv=False, add_zero_attn=False
    ):
        """Build the decoder self-attention module.

        With cross_self_attention enabled it is built with
        ``self_attention=False``, because its keys/values then include the
        encoder output.
        """
        return MultiheadAttention(
            embed_dim,
            nhead,
            dropout=attn_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def build_encoder_attention(
        self, embed_dim, encoder_embed_dim, attn_dropout, nhead
    ):
        """Build the encoder-decoder attention (keys/values of size encoder_embed_dim)."""
        return MultiheadAttention(
            embed_dim,
            nhead,
            kdim=encoder_embed_dim,
            vdim=encoder_embed_dim,
            dropout=attn_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def prepare_for_onnx_export_(self):
        """Enable ONNX tracing mode (forward then also returns self-attn state)."""
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        """Return ``residual + x`` (plain additive residual)."""
        return residual + x

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder output; used as keys/values
                of the encoder-decoder attention (and prepended to the
                self-attention keys/values when cross_self_attention is on).
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (dict, optional): per-module cache used for
                incremental decoding; required when either ``prev_*_state``
                is given.
            prev_self_attn_state (list, optional): ``[prev_key, prev_value]``
                or ``[prev_key, prev_value, prev_key_padding_mask]``, written
                into the self-attention cache before attending.
            prev_attn_state (list, optional): same layout, written into the
                encoder-decoder attention cache.
            self_attn_mask (Tensor, optional): additive/boolean mask for
                self-attention, passed through as ``attn_mask``.
            self_attn_padding_mask (Tensor, optional): key padding mask for
                self-attention.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            tuple ``(x, attn, self_attn_state)``: ``x`` is the output of shape
            `(seq_len, batch, embed_dim)`. ``attn`` is what the last attention
            call returned as weights; the self-attention call uses
            ``need_weights=False``, so without encoder attention it is
            presumably None (confirm against fairseq MultiheadAttention).
            ``self_attn_state`` is a list of cached tensors only when
            ``onnx_trace`` is on and ``incremental_state`` is given,
            otherwise None.
        """
        if need_head_weights:
            need_attn = True

        # --- self-attention sublayer ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            # Seed the self-attention cache from an explicitly supplied state.
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        # With cross self-attention, keys/values are cat(encoder_out, x), and
        # the masks are widened to match. This is skipped once the cache
        # already holds a prev_key (later incremental steps).
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                # Zero (= not masked) columns for the prepended encoder positions.
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- encoder-decoder attention sublayer (optional) ---
        if self.encodec_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encodec_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encodec_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encodec_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                # In eval mode self.need_attn also forces weights to be returned.
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encodec_attn_layer_norm(x)

        # --- feed-forward sublayer ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        # Under ONNX tracing the self-attention cache is returned explicitly.
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        """Set ``self.need_attn``; extra keyword arguments are ignored."""
        self.need_attn = need_attn
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Positions are counted from 0 over the non-padding tokens of each row
    (see ``make_positions``). Padding tokens receive an all-zero embedding.

    The lookup table is a plain tensor attribute, not a registered buffer, so
    ``nn.Module.to()`` does not move it. It is placed on ``device`` at
    construction and grown in ``forward`` when a longer input arrives.

    Args:
        embedding_dim: size of each positional embedding.
        padding_idx: token id treated as padding.
        init_size: number of positions precomputed at construction.
        fp16: store the table in half precision.
        device: device for the table (defaults to ``cuda:0``).
    """

    def __init__(
        self,
        embedding_dim,
        padding_idx,
        init_size=1024,
        fp16=False,
        device=torch.device("cuda:0"),
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        ).to(device)
        if fp16:
            self.weights = self.weights.to(torch.half)

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Returns a float tensor of shape ``(num_embeddings, embedding_dim)``.
        Row ``i`` depends only on ``i`` and ``embedding_dim``: the first half
        holds ``sin(i * f_k)`` and the second half ``cos(i * f_k)``. An odd
        ``embedding_dim`` gets a trailing zero column. ``padding_idx`` is
        accepted for API compatibility but is not used.
        """
        half_dim = embedding_dim // 2
        # Guard the denominator so half_dim == 1 (embedding_dim 2 or 3) gives a
        # single frequency of 1 instead of ZeroDivisionError. For
        # half_dim >= 2 this is exactly the original formula.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        # Explicit width (rather than -1) also works when half_dim == 0.
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, 2 * half_dim
        )
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        return emb

    def make_positions(self, tensor, padding_idx):
        """Return 0-based positions of non-padding tokens along dim 1.

        Padding entries are mapped to position 0 (and are zeroed later by
        ``forward``).
        """
        mask = tensor.ne(padding_idx).int()
        return ((torch.cumsum(mask, dim=1).type_as(mask) - 1) * mask).long()

    def forward(
        self,
        input,
        incremental_state=None,
        timestep=None,
        positions=None,
    ):
        """Input is expected to be of size [bsz x seqlen].

        Returns detached ``(bsz, seqlen, embedding_dim)`` positional
        embeddings, zero at padding tokens. ``incremental_state``,
        ``timestep`` and ``positions`` are accepted for interface
        compatibility but ignored; positions are always recomputed from
        ``input``.
        """
        bsz, seq_len = input.size(0), input.size(1)
        # make_positions never yields an index >= seq_len, so a table with at
        # least seq_len rows always suffices. Grow it (keeping dtype/device)
        # instead of letting index_select fail. Rows depend only on their
        # index, so existing entries are unchanged.
        if seq_len > self.weights.size(0):
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                seq_len, self.embedding_dim, self.padding_idx
            ).to(self.weights)
        positions = self.make_positions(input, self.padding_idx)
        mask = (
            torch.ne(input, self.padding_idx)
            .unsqueeze(2)
            .expand(bsz, seq_len, self.embedding_dim)
        )
        return (
            self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)
            * mask
        ).detach()
class FSTransformerEmbeddingLayer(nn.Module):
    """Token + sinusoidal positional embedding baseline (fairseq style).

    Computes ``dropout(sqrt(embedding_dim) * embed(tokens) + pos(tokens))``.
    The token table is created on ``cuda:0`` in half precision when ``fp16``
    is true, otherwise in float32.

    Args:
        vocab_size: number of token ids.
        embedding_dim: embedding size.
        max_seq_len: number of positions precomputed for the positional table.
        padding_idx: token id whose embedding row is initialised to zero.
        dropout: dropout probability applied in training mode.
        fp16: use half precision for the embeddings.
    """

    def __init__(
        self, vocab_size, embedding_dim, max_seq_len, padding_idx, dropout, fp16
    ):
        super().__init__()
        self.embeddings = nn.Embedding(
            vocab_size, embedding_dim, padding_idx=padding_idx
        )
        # Normal init scaled by 1/sqrt(d), with the padding row forced to zero.
        nn.init.normal_(self.embeddings.weight, mean=0, std=embedding_dim ** -0.5)
        nn.init.constant_(self.embeddings.weight[padding_idx], 0)
        self.embeddings.to(
            torch.device("cuda:0"), dtype=(torch.half if fp16 else torch.float)
        )
        self.embed_positions = SinusoidalPositionalEmbedding(
            embedding_dim, padding_idx, max_seq_len, fp16
        ).to(torch.device("cuda:0"))
        self.embedding_dim = embedding_dim
        self.dropout = dropout

    def forward(self, input):
        """Embed a ``[bsz x seqlen]`` tensor of token ids."""
        x = self.embeddings(input)
        x = math.sqrt(self.embedding_dim) * x
        x += self.embed_positions(input)
        # Follow the module's train/eval mode so that .eval() disables dropout
        # (modules start in training mode, so default behavior is unchanged).
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x
class FSCrossEntropyLayer(nn.Module):
    """Label-smoothed cross-entropy baseline in the fairseq formulation.

    ``loss = (1 - eps - eps_i) * nll + eps_i * smooth``, where ``eps_i``
    spreads ``epsilon`` over the ``V - 1`` non-target classes of a
    ``V``-way distribution. Targets equal to ``ignore_index`` (when it is
    not None) contribute nothing.

    Args:
        epsilon: label-smoothing factor.
        ignore_index: target id to mask out, or None to keep every target.
    """

    def __init__(self, epsilon, ignore_index):
        super().__init__()
        self.epsilon = epsilon
        self.ignore_index = ignore_index

    def label_smoothed_nll_loss(self, lprobs, target, reduce=True):
        """Compute ``(loss, nll_loss)`` from log-probabilities.

        ``target`` may have one dimension fewer than ``lprobs``; a trailing
        singleton class axis is then added. With ``reduce`` both results
        are summed to scalars.
        """
        if lprobs.dim() - 1 == target.dim():
            target = target.unsqueeze(-1)
        nll_loss = -lprobs.gather(dim=-1, index=target)
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)

        if self.ignore_index is None:
            nll_loss, smooth_loss = nll_loss.squeeze(-1), smooth_loss.squeeze(-1)
        else:
            is_pad = target.eq(self.ignore_index)
            nll_loss = nll_loss.masked_fill(is_pad, 0.0)
            smooth_loss = smooth_loss.masked_fill(is_pad, 0.0)

        if reduce:
            nll_loss, smooth_loss = nll_loss.sum(), smooth_loss.sum()

        num_classes = lprobs.size(-1)
        eps_per_class = self.epsilon / (num_classes - 1)
        true_class_weight = 1.0 - self.epsilon - eps_per_class
        loss = true_class_weight * nll_loss + eps_per_class * smooth_loss
        return loss, nll_loss

    def forward(self, inputs, targets):
        """Return ``(loss, nll_loss)`` for logits ``inputs``, cast back to ``inputs``' dtype/device."""
        # Log-softmax in float32 for numerical stability, even for fp16 logits.
        log_probs = F.log_softmax(inputs, dim=-1, dtype=torch.float32)
        loss, nll_loss = self.label_smoothed_nll_loss(log_probs, targets)
        return loss.to(inputs), nll_loss.to(inputs)
def get_fairseq_enc_params(fairseq_layer):
    """Snapshot the parameters of an encoder layer.

    Returns ``(weights, biases)``: two lists of detached clones, ordered as
    self-attn q/k/v/out projections, self-attn layer norm, fc1, fc2, and
    final layer norm.
    """
    attn = fairseq_layer.self_attn
    ordered_modules = (
        attn.q_proj,
        attn.k_proj,
        attn.v_proj,
        attn.out_proj,
        fairseq_layer.self_attn_layer_norm,
        fairseq_layer.fc1,
        fairseq_layer.fc2,
        fairseq_layer.final_layer_norm,
    )
    weights = [module.weight.detach().clone() for module in ordered_modules]
    biases = [module.bias.detach().clone() for module in ordered_modules]
    return weights, biases
def get_fairseq_dec_params(fairseq_layer):
    """Snapshot the parameters of a decoder layer.

    Returns ``(weights, biases)``: two lists of detached clones, ordered as
    self-attn q/k/v/out projections, self-attn layer norm, encoder-attn
    (``encodec_attn``) q/k/v/out projections, encoder-attn layer norm, fc1,
    fc2, and final layer norm.
    """
    self_attn = fairseq_layer.self_attn
    cross_attn = fairseq_layer.encodec_attn
    ordered_modules = (
        self_attn.q_proj,
        self_attn.k_proj,
        self_attn.v_proj,
        self_attn.out_proj,
        fairseq_layer.self_attn_layer_norm,
        cross_attn.q_proj,
        cross_attn.k_proj,
        cross_attn.v_proj,
        cross_attn.out_proj,
        fairseq_layer.encodec_attn_layer_norm,
        fairseq_layer.fc1,
        fairseq_layer.fc2,
        fairseq_layer.final_layer_norm,
    )
    weights = [module.weight.detach().clone() for module in ordered_modules]
    biases = [module.bias.detach().clone() for module in ordered_modules]
    return weights, biases