forked from mosaicml/examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmosaic_gpt.py
509 lines (430 loc) · 19.5 KB
/
mosaic_gpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
"""A simple, flexible implementation of a GPT model.
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""
import math
import warnings
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from composer.metrics.nlp import LanguageCrossEntropy, Perplexity
from composer.models.base import ComposerModel
from omegaconf import DictConfig
class TorchCausalAttention(nn.Module):
    """Causal multi-head self-attention backed by ``torch.nn.MultiheadAttention``.

    Causality (and optionally ALiBi) is supplied through an additive float
    ``attn_mask`` built by :meth:`attn_mask_`. Constructing this class emits a
    ``DeprecationWarning`` pointing users at the triton implementation.
    """

    def __init__(self, cfg: DictConfig, device: Optional[str] = None):
        super().__init__()
        self.mhsa = nn.MultiheadAttention(
            embed_dim=cfg.d_model,
            num_heads=cfg.n_heads,
            dropout=cfg.attn_pdrop,
            bias=True,
            batch_first=True,
            device=device,
        )
        # Tag read by MosaicGPT.param_init_fn to apply depth-scaled init.
        self.mhsa.out_proj._is_residual = True  # type: ignore
        warnings.warn(
            DeprecationWarning(
                'Using `attn_impl: torch` is deprecated; recommened using `attn_impl: triton`.'
            ))

    def forward(self, x, key_padding_mask, attn_mask=None):
        """Apply causal self-attention to ``x``.

        Args:
            x: input with batch as the first dimension (``batch_first=True``).
            key_padding_mask: boolean mask where True marks tokens to attend to
                (the HuggingFace convention used by the rest of this file), or
                None when there is no padding.
            attn_mask: additive float mask produced by :meth:`attn_mask_`, or None.

        Returns:
            The ``(output, weights)`` tuple returned by ``nn.MultiheadAttention``.
        """
        if key_padding_mask is not None:
            # torch uses the opposite convention: True means "do not attend".
            key_padding_mask = ~key_padding_mask
        # NOTE(review): callers in this file discard the weights; need_weights=False
        # would let torch take its faster attention path.
        return self.mhsa(x,
                         x,
                         x,
                         attn_mask=attn_mask,
                         key_padding_mask=key_padding_mask,
                         need_weights=True)

    @staticmethod
    def mask_shape(n_heads, seq_len, alibi):
        """Shape of the mask buffer; per-head when ALiBi biases are folded in."""
        if alibi:
            return (n_heads, seq_len, seq_len)
        return (seq_len, seq_len)

    @staticmethod
    def attn_mask_(attn_mask, n_heads, seq_len, alibi=False, alibi_bias_max=8):
        """Fill ``attn_mask`` in place with a causal (and optional ALiBi) bias.

        Entries strictly above the diagonal become ``-inf``; the rest are 0,
        plus the ALiBi bias when ``alibi`` is True. Returns ``attn_mask``.
        """
        # Two important disclaimers
        # 1. Torch uses additive attention. If your attn_mask/key_padding mask is a float tensor, it will add the floats
        # directly to your attention matrix. If they are boolean masks, True will be converted to -inf before adding the
        # mask to your attentions. See https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention.forward
        # Basically True/-inf indicates tokens we do not want to attend to.
        #
        # 2. This is the exact opposite behavior of Huggingface's tokenizers, which use the convention that True denotes tokens
        # we do want to attend to. See https://huggingface.co/docs/transformers/glossary#attention-mask
        attn_mask.fill_(float('-inf'))
        attn_mask.triu_(diagonal=1)
        if alibi:
            device, dtype = attn_mask.device, attn_mask.dtype
            a_bias = alibi_bias(n_heads,
                                seq_len,
                                full=True,
                                alibi_bias_max=alibi_bias_max,
                                device=device,
                                dtype=dtype)
            # Drop only the leading singleton dim: (1, H, S, S) -> (H, S, S).
            # A bare squeeze() would also drop H or S when they equal 1.
            attn_mask.add_(a_bias.squeeze(0))
        return attn_mask
class FlashCausalAttention(nn.Module):
    """Causal multi-head self-attention backed by ``flash_attn``'s ``FlashMHA``.

    The kernel is built with ``causal=True``, so no explicit attention-mask
    buffer is needed: :meth:`mask_shape` and :meth:`attn_mask_` return None.
    Constructing this class emits a ``DeprecationWarning``.
    """

    def __init__(self, cfg: DictConfig, device: Optional[str] = None):
        super().__init__()
        try:
            # Optional dependency, imported here so the module loads without it.
            from flash_attn.flash_attention import FlashMHA  # type: ignore
        except ImportError as e:
            raise ImportError(
                '`attn_impl: flash` requires the `flash_attn` package to be installed.'
            ) from e
        self.mhsa = FlashMHA(
            embed_dim=cfg.d_model,
            num_heads=cfg.n_heads,
            attention_dropout=cfg.attn_pdrop,
            bias=True,
            batch_first=True,
            causal=True,
            device=device,
        )
        # Tag read by MosaicGPT.param_init_fn to apply depth-scaled init.
        self.mhsa.out_proj._is_residual = True  # type: ignore
        warnings.warn(
            DeprecationWarning(
                'Using `attn_impl: flash` is deprecated; recommened using `attn_impl: triton`.'
            ))

    def forward(self, x, key_padding_mask, attn_mask=None):
        """Apply causal self-attention to ``x``.

        ``attn_mask`` must be None because causality is handled inside the kernel.
        """
        assert attn_mask is None
        return self.mhsa(x,
                         key_padding_mask=key_padding_mask,
                         need_weights=False)

    @staticmethod
    def mask_shape(*args, **kwargs):
        """No mask buffer is needed for this backend."""
        return None

    @staticmethod
    def attn_mask_(*args, **kwargs):
        """No mask to fill for this backend."""
        return None
class TritonFlashCausalAttention(nn.Module):
    """Multi-headed self attention using triton FlashAttn kernel.

    This also includes bias for Alibi integration: the additive ``attn_mask``
    (see :meth:`attn_mask_`) carries the ALiBi bias and any padding, while
    causality is enforced by the kernel (``causal=True``).
    """

    def __init__(self, cfg: DictConfig, device: Optional[str] = None):
        super().__init__()
        try:
            # Project-local kernel, imported here so the module loads without it.
            from src.flash_attention import FlashMHA  # type: ignore
        except ImportError as e:
            raise ImportError(
                '`attn_impl: triton` requires the `src.flash_attention` module to be importable.'
            ) from e
        assert cfg.attn_pdrop == 0, 'triton kernel does not support attn_dropout'
        self.mhsa = FlashMHA(
            embed_dim=cfg.d_model,
            num_heads=cfg.n_heads,
            bias=True,
            batch_first=True,
            causal=True,
            device=device,
        )
        # Tag read by MosaicGPT.param_init_fn to apply depth-scaled init.
        self.mhsa.out_proj._is_residual = True  # type: ignore

    def forward(self, x, key_padding_mask=None, attn_mask=None):
        """Apply causal self-attention to ``x``.

        ``key_padding_mask`` must be None: MosaicGPT folds padding into
        ``attn_mask`` for this backend instead.
        """
        assert key_padding_mask is None
        return self.mhsa(x,
                         key_padding_mask=None,
                         attn_mask=attn_mask,
                         need_weights=False)

    @staticmethod
    def mask_shape(n_heads, seq_len, alibi):
        """Shape of the bias buffer; it is broadcast over the query dimension."""
        return (1, n_heads, 1, seq_len) if alibi else (1, 1, 1, seq_len)

    @staticmethod
    def attn_mask_(attn_mask, n_heads, seq_len, alibi=False, alibi_bias_max=8):
        """Fill ``attn_mask`` in place with zeros plus the optional ALiBi bias.

        Returns ``attn_mask``.
        """
        attn_mask.zero_()
        if alibi:
            device, dtype = attn_mask.device, attn_mask.dtype
            attn_mask.add_(
                alibi_bias(n_heads,
                           seq_len,
                           full=False,
                           alibi_bias_max=alibi_bias_max,
                           device=device,
                           dtype=dtype))
        return attn_mask
def alibi_bias(n_heads,
               seq_len,
               full=False,
               alibi_bias_max=8,
               device=None,
               dtype=None):
    """Build ALiBi (attention with linear biases) position biases.

    Head ``h`` (1-indexed) uses slope ``1 / 2**(h * alibi_bias_max / n_heads)``.

    Args:
        n_heads: number of attention heads.
        seq_len: sequence length the bias is built for.
        full: if True, return a ``(1, n_heads, seq_len, seq_len)`` tensor whose
            entry ``[0, h, i, j]`` is ``-|i - j| * slope_h``. Otherwise return a
            ``(1, n_heads, 1, seq_len)`` tensor whose entry ``[0, h, 0, j]`` is
            ``(j - (seq_len - 1)) * slope_h``, meant to be broadcast over queries.
        alibi_bias_max: exponent of the smallest slope (see above).
        device: device of the returned tensor.
        dtype: dtype of the returned tensor; defaults to
            ``torch.get_default_dtype()``.

    Returns:
        The bias tensor; every entry is <= 0.
    """
    if dtype is None:
        # An integer default (what torch.arange infers) would make the float
        # scaling of the slopes below fail.
        dtype = torch.get_default_dtype()
    # Build in at least fp32: fp16/bf16 cannot represent large position indices
    # exactly (bf16 is already lossy above 256), which would distort the bias.
    compute_dtype = torch.promote_types(dtype, torch.float32)

    positions = torch.arange(1 - seq_len, 1, dtype=compute_dtype, device=device)
    bias = positions.view(1, 1, 1, seq_len)
    if full:
        # generate 1 x Heads x SeqLen x SeqLen alibi bias mask
        # otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast up to the appropriate size)
        bias = bias - positions.view(1, 1, seq_len, 1)
        # since we're using causal flag, this isn't really needed, but why not include it
        bias = bias.abs().mul(-1)

    m = torch.arange(1, n_heads + 1, dtype=compute_dtype, device=device)
    m.mul_(alibi_bias_max / n_heads)
    bias = bias * (1. / (2**m.view(1, n_heads, 1, 1)))
    return bias.to(dtype)
class GPTMLP(nn.Module):
    """Position-wise feed-forward network: Linear -> exact GELU -> Linear.

    The hidden width is ``cfg.mlp_ratio * cfg.d_model``. The down projection is
    tagged ``_is_residual`` so ``MosaicGPT.param_init_fn`` gives it the
    depth-scaled initialization.
    """

    def __init__(self, cfg: DictConfig, device: Optional[str] = None):
        super().__init__()
        hidden_dim = cfg.mlp_ratio * cfg.d_model
        self.mlp_up = nn.Linear(cfg.d_model, hidden_dim, device=device)
        self.mlp_act = nn.GELU(approximate='none')
        self.mlp_down = nn.Linear(hidden_dim, cfg.d_model, device=device)
        self.mlp_down._is_residual = True  # type: ignore

    def forward(self, x):
        hidden = self.mlp_up(x)
        hidden = self.mlp_act(hidden)
        return self.mlp_down(hidden)
class GPTBlock(nn.Module):
    """A pre-LayerNorm transformer block.

    Computes ``x + dropout(attn(ln_1(x)))`` and then adds
    ``dropout(mlp(ln_2(.)))`` of that result as a second residual.
    """

    def __init__(self,
                 cfg: DictConfig,
                 causal_attn_cls,
                 device: Optional[str] = None):
        super().__init__()
        if cfg.get('alibi', False):
            # The ALiBi bias travels through attn_mask, which only these
            # backends accept.
            assert cfg.attn_impl in ('triton', 'torch'), 'Only triton kernel or torch supports alibi'
        self.ln_1 = nn.LayerNorm(cfg.d_model, device=device)
        self.causal_attn = causal_attn_cls(cfg, device)
        self.ln_2 = nn.LayerNorm(cfg.d_model, device=device)
        self.mlp = GPTMLP(cfg, device=device)
        self.resid_attn_dropout = nn.Dropout(cfg.resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(cfg.resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.ByteTensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-layer; the attention weights (second output) are unused.
        attn_out, _ = self.causal_attn(self.ln_1(x), key_padding_mask,
                                       attn_mask)
        x = x + self.resid_attn_dropout(attn_out)
        # Feed-forward sub-layer.
        mlp_out = self.mlp(self.ln_2(x))
        return x + self.resid_mlp_dropout(mlp_out)
class MosaicGPT(nn.Module):
    """Decoder-only GPT language model.

    Token embeddings, plus learned position embeddings when ALiBi is disabled,
    feed a stack of ``GPTBlock``s and a final LayerNorm. Logits are produced
    with the token-embedding matrix (weight tying). The attention backend is
    selected by ``cfg.attn_impl`` (``'torch'``, ``'flash'`` or ``'triton'``).
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        assert cfg.name == 'mosaic_gpt', f'Tried to build MosaicGPT model with cfg.name={cfg.name}'
        self.cfg = cfg
        if cfg.attn_impl == 'torch':
            self.causal_attn_cls = TorchCausalAttention
        elif cfg.attn_impl == 'flash':
            self.causal_attn_cls = FlashCausalAttention
        elif cfg.attn_impl == 'triton':
            self.causal_attn_cls = TritonFlashCausalAttention
        else:
            raise ValueError(f'Unknown attn_impl={cfg.attn_impl}')

        self.alibi = cfg.get('alibi', False)
        self.alibi_bias_max = cfg.get('alibi_bias_max',
                                      8 if self.alibi else None)
        # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
        # both report this helping with stabilizing training
        self.embedding_fraction = cfg.get('embedding_fraction', 1)
        assert 0 < self.embedding_fraction <= 1, 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'

        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(cfg.vocab_size, cfg.d_model, device=cfg.device)
        })
        if not self.alibi:
            # With ALiBi, position enters through the attention bias instead,
            # so the learned position table is only built without it.
            self.transformer.update({
                'wpe':
                    nn.Embedding(cfg.max_seq_len,
                                 cfg.d_model,
                                 device=cfg.device)
            })
        self.transformer.update({'emb_drop': nn.Dropout(cfg.emb_pdrop)})
        self.transformer.update({
            'blocks':
                nn.ModuleList([
                    GPTBlock(cfg,
                             causal_attn_cls=self.causal_attn_cls,
                             device=cfg.device) for _ in range(cfg.n_layers)
                ])
        })
        self.transformer.update(
            {'ln_f': nn.LayerNorm(cfg.d_model, device=cfg.device)})

        if cfg.device != 'meta':
            self.apply(self.param_init_fn)

        # define attn mask; the buffer is filled lazily by _attn_mask on the
        # first forward call
        self._attn_mask_initialized = False
        mask_shape = self.causal_attn_cls.mask_shape(cfg.n_heads,
                                                     cfg.max_seq_len,
                                                     self.alibi)
        if mask_shape:
            self.register_buffer('attn_mask',
                                 torch.empty(mask_shape, device=cfg.device))
        else:
            self.attn_mask = None

    def _attn_mask(self, batch_size=None, seq_len=None, key_padding_mask=None):
        """Return the attention mask for a ``(batch_size, seq_len)`` input.

        On the first call the full ``max_seq_len`` buffer is filled in place by
        ``causal_attn_cls.attn_mask_``. The buffer is then sliced to
        ``seq_len``. When ``key_padding_mask`` contains padding (False entries),
        the padding is folded in as ``-inf`` for the triton and torch backends.
        Returns None for the flash backend, whose kernel is built causal.
        """
        if not self._attn_mask_initialized:
            self.causal_attn_cls.attn_mask_(self.attn_mask,
                                            self.cfg.n_heads,
                                            self.cfg.max_seq_len,
                                            alibi=self.alibi,
                                            alibi_bias_max=self.alibi_bias_max)
            self._attn_mask_initialized = True

        if self.cfg.attn_impl == 'flash':
            return self.attn_mask  # None

        # select seq_len subset of attn mask
        assert self.attn_mask is not None, 'Internal logic error'
        attn_mask = self.attn_mask[..., :seq_len, :seq_len]

        if self.cfg.attn_impl == 'triton' and key_padding_mask is not None and key_padding_mask.bool(
        ).logical_not().any():
            attn_mask = attn_mask.masked_fill(
                ~key_padding_mask.view(batch_size, 1, 1, seq_len),
                float('-inf'))

        if self.cfg.attn_impl == 'torch':
            if key_padding_mask is not None and key_padding_mask.bool(
            ).logical_not().any():
                attn_mask = attn_mask.expand(batch_size, self.cfg.n_heads,
                                             seq_len, seq_len).clone()
                attn_mask.masked_fill_(
                    ~key_padding_mask.view(batch_size, 1, 1, seq_len),
                    float('-inf'))
                attn_mask = attn_mask.reshape(-1, seq_len, seq_len)
            elif self.alibi:
                # WARNING: Alibi with torch attn is not thoroughly tested
                # torch mask is supposed to be of shape nzz x SeqLen x SeqLen
                # we must broadcast to batch size then flatten batchsize * n_heads dim
                # Note: if key_padding_mask is triggered, the needed expansion is already done.
                attn_mask = attn_mask.expand(batch_size, self.cfg.n_heads,
                                             seq_len, seq_len).reshape(
                                                 -1, seq_len, seq_len)

        return attn_mask

    def forward(self,
                input_ids: torch.LongTensor,
                key_padding_mask: Optional[torch.ByteTensor] = None):
        """Compute next-token logits.

        Args:
            input_ids: token ids of shape (batch, seq_len), with
                seq_len <= ``cfg.max_seq_len``.
            key_padding_mask: optional mask of the same shape where True marks
                real tokens and False marks padding.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        B, S = input_ids.size()
        assert (
            S <= self.cfg.max_seq_len
        ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.cfg.max_seq_len}'

        tok_emb = self.transformer.wte(input_ids)  # type: ignore
        if self.alibi:
            x = tok_emb
        else:
            pos = torch.arange(0, S, dtype=torch.long,
                               device=input_ids.device).unsqueeze(0)
            pos_emb = self.transformer.wpe(pos)  # type: ignore
            x = tok_emb + pos_emb

        if self.embedding_fraction == 1:
            x = self.transformer.emb_drop(x)  # type: ignore
        else:
            # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
            x_shrunk = (x * self.embedding_fraction) + (
                x.detach() * (1 - self.embedding_fraction))
            assert isinstance(self.transformer.emb_drop, nn.Module)  # pyright
            x = self.transformer.emb_drop(x_shrunk)

        attn_mask = self._attn_mask(batch_size=B,
                                    seq_len=S,
                                    key_padding_mask=key_padding_mask)
        for block in self.transformer.blocks:  # type: ignore
            # The triton backend already has padding folded into attn_mask.
            x = block(
                x, None if self.cfg.attn_impl == 'triton' else key_padding_mask,
                attn_mask)

        x = self.transformer.ln_f(x)  # type: ignore
        # output embedding weight tied to input embedding
        assert isinstance(self.transformer.wte, nn.Module)  # pyright
        assert isinstance(self.transformer.wte.weight, torch.Tensor)  # pyright
        logits = F.linear(x, self.transformer.wte.weight, None)
        return logits

    # Param Initialization, needed for device='meta' fast initialization
    def param_init_fn(self, module):
        """Initialize ``module``'s parameters in place (used with ``self.apply``).

        Weights are drawn from N(0, init_std); modules tagged ``_is_residual``
        use std ``init_std / sqrt(2 * n_layers)``. Biases are zeroed and
        LayerNorm weights are set to one.
        """
        init_fn = partial(torch.nn.init.normal_,
                          mean=0.0,
                          std=self.cfg.init_std)
        residual_std = self.cfg.init_std / math.sqrt(2 * self.cfg.n_layers)

        # Linear
        if isinstance(module, nn.Linear):
            init_fn(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            if getattr(module, '_is_residual', False):
                module.weight.data.normal_(mean=0.0, std=residual_std)

        # Embedding
        if isinstance(module, nn.Embedding):
            init_fn(module.weight)

        # LayerNorm
        if isinstance(module, nn.LayerNorm):
            # Guard for LayerNorms built without affine parameters.
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            if module.weight is not None:
                torch.nn.init.ones_(module.weight)

        # torch's MultiheadAttention
        if isinstance(module, nn.MultiheadAttention):
            if module._qkv_same_embed_dim:
                assert module.in_proj_weight is not None
                assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
                init_fn(module.in_proj_weight)
            else:
                assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
                assert module.in_proj_weight is None
                init_fn(module.q_proj_weight)
                init_fn(module.k_proj_weight)
                init_fn(module.v_proj_weight)

            # bias
            if module.in_proj_bias is not None:
                torch.nn.init.zeros_(module.in_proj_bias)
            if module.bias_k is not None:
                torch.nn.init.zeros_(module.bias_k)
            if module.bias_v is not None:
                torch.nn.init.zeros_(module.bias_v)

            # out proj; untagged out_proj layers fall back to the default init
            if getattr(module.out_proj, '_is_residual', False):
                module.out_proj.weight.data.normal_(mean=0.0, std=residual_std)
            else:
                init_fn(module.out_proj.weight)
            if module.out_proj.bias is not None:
                torch.nn.init.zeros_(module.out_proj.bias)

    # FSDP Wrap function
    def fsdp_wrap_fn(self, module):
        """Wrap each ``GPTBlock`` as its own FSDP unit."""
        return isinstance(module, GPTBlock)

    # Activation Checkpointing
    def activation_checkpointing_fn(self, module):
        """Checkpoint activations at ``GPTBlock`` granularity."""
        return isinstance(module, GPTBlock)
class ComposerMosaicGPT(ComposerModel):
    """Composer wrapper around :class:`MosaicGPT` for causal language modeling.

    Provides the next-token cross-entropy loss, train/eval metrics and a
    forward-FLOPs estimate.
    """

    def __init__(self, cfg):
        super().__init__()
        self.model = MosaicGPT(cfg)
        # Lazily computed by num_fwd_flops.
        self.__num_fwd_flops = None
        self.train_metrics = {
            'LanguageCrossEntropy': LanguageCrossEntropy(cfg.vocab_size),
            'Perplexity': Perplexity(),
        }
        self.eval_metrics = {
            'LanguageCrossEntropy': LanguageCrossEntropy(cfg.vocab_size),
            'Perplexity': Perplexity(),
        }

    def get_targets(self, batch):
        """Return next-token targets: ``batch['labels']`` shifted left by one.

        Shifting is along dim 1 (the sequence dim). The final position has no
        next token and is set to -100, the ``ignore_index`` used by :meth:`loss`.
        """
        targets = torch.roll(batch['labels'], shifts=-1, dims=1)
        targets[:, -1] = -100
        return targets

    def forward(self, batch):
        return self.model(batch['input_ids'],
                          key_padding_mask=batch['attention_mask'].bool())

    def eval_forward(self, batch, outputs=None):
        """Reuse ``outputs`` when already computed; otherwise run :meth:`forward`."""
        return outputs if outputs is not None else self.forward(batch)

    def loss(self, outputs, batch):
        """Token-level cross entropy between logits and next-token targets."""
        targets = self.get_targets(batch)
        return F.cross_entropy(outputs.view(-1, outputs.size(-1)),
                               targets.view(-1),
                               ignore_index=-100)

    def get_metrics(self, is_train=False):
        return self.train_metrics if is_train else self.eval_metrics

    def update_metric(self, batch, outputs, metric):
        outputs = outputs.view(-1, outputs.size(-1))
        targets = self.get_targets(batch).view(-1)
        metric.update(outputs, targets)

    @property
    def num_fwd_flops(self):
        """Approximate forward-pass FLOPs for one ``max_seq_len`` sequence (cached)."""
        # `is not None` so that a cached value of 0 is not recomputed.
        if self.__num_fwd_flops is not None:
            return self.__num_fwd_flops
        n_params = sum(p.numel() for p in self.parameters())
        # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
        # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
        # this gets us FLOPs / token
        params_flops_per_token = 2 * n_params
        params_flops_per_seq = params_flops_per_token * self.model.cfg.max_seq_len
        # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2)
        attn_flops_per_seq = self.model.cfg.n_layers * 2 * 2 * (
            self.model.cfg.d_model * (self.model.cfg.max_seq_len**2))
        self.__num_fwd_flops = params_flops_per_seq + attn_flops_per_seq
        return self.__num_fwd_flops