-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathmamba_block.py
executable file
·588 lines (495 loc) · 24.2 KB
/
mamba_block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
import math
from typing import Optional, Union
import re
from contextlib import nullcontext
from abc import ABC, abstractmethod
from dataclasses import dataclass
import functools
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
try:
from ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
except ImportError:
selective_scan_fn, mamba_inner_fn = None, None
try:
from ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
try:
from ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from mamba_layer import MambaLayer
from mamba2_layer import Mamba2Layer
from mamba_config import MambaConfig
from mlp import MLP
from attention import CausalSelfAttention
from rotary import RotaryEmbedding
class MambaBlock(nn.Module):
def __init__(
self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
):
super().__init__()
self.config = config
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.mixer = mixer_cls(config)
if config.use_module_layernorm and not config.rms_norm:
self.norm = norm_cls
else:
self.norm = norm_cls(config.hidden_size)
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
assert config.num_mem_heads == 0, 'args.num_mem_heads > 0 only supports fused_add_norm=False'
if moe_cls is not None:
self.moe = moe_cls(config)
else:
self.moe = None
def forward(
self, hidden_states: Tensor, from_shared_proj: Optional[Tensor] = None, from_tf: Optional[Tensor] = None, residual: Optional[Tensor] = None, inference_params=None, attention_mask=None
):
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
if from_tf is not None:
hidden_states = self.norm((residual + from_tf).to(dtype=self.norm.weight.dtype))
else:
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
hidden_states, residual = fused_add_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states = self.mixer(hidden_states, from_shared_proj=from_shared_proj, inference_params=inference_params)
return hidden_states , residual
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
class AttentionBlock(nn.Module):
def __init__(
self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
):
super().__init__()
self.config = config
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.mixer = mixer_cls(config)
if config.use_module_layernorm and not config.rms_norm:
self.norm = norm_cls
else:
self.norm = norm_cls(config.hidden_size)
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
assert config.num_mem_heads == 0, 'args.num_mem_heads > 0 only supports fused_add_norm=False'
if moe_cls is not None:
self.moe = moe_cls(config)
else:
self.moe = None
self.rotary_pos_emb = RotaryEmbedding(
config.kv_channels, rotary_percent=1.0, seq_len_interpolation_factor=None
)
def forward(
self, hidden_states: Tensor, from_tf: Optional[Tensor] = None, residual: Optional[Tensor] = None, inference_params=None, attention_mask=None
):
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
if from_tf is not None:
hidden_states = self.norm((residual + from_tf).to(dtype=self.norm.weight.dtype))
else:
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
hidden_states, residual = fused_add_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states = hidden_states.transpose(0,1).contiguous()
rotary_seq_len = hidden_states.shape[0]
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
hidden_states = self.mixer(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask, inference_params=inference_params)
hidden_states = hidden_states.transpose(0,1)
return hidden_states , residual
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
class Memory_AttentionBlock(nn.Module):
def __init__(
self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, residual_in_fp32=False, fused_add_norm=False
):
super().__init__()
self.config = config
self.residual_in_fp32 = residual_in_fp32
self.mixer = mixer_cls(config)
assert config.rms_norm, 'Memory_AttentionBlock only supports RMSNorm'
self.norm = norm_cls(2 * config.hidden_size)
self.fused_add_norm = fused_add_norm
if moe_cls is not None:
self.moe = moe_cls(config)
else:
self.moe = None
def forward(
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, attention_mask=None, rotary_pos_emb=None, forward_layer_idx = None
):
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
hidden_states, residual = fused_add_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states = hidden_states.transpose(0,1).contiguous()
hidden_states = self.mixer(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask, inference_params=inference_params, forward_layer_idx = forward_layer_idx)
hidden_states = hidden_states.transpose(0,1)
return hidden_states
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
class MambaBlockParallelMoe(nn.Module):
def __init__(
self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, norm_moe=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
):
super().__init__()
self.config = config
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.mixer = mixer_cls(config)
if config.use_module_layernorm and not config.rms_norm:
self.norm = norm_cls
self.norm_moe = norm_moe
else:
self.norm = norm_cls(config.hidden_size)
self.norm_moe = norm_moe(config.hidden_size)
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
assert isinstance(
self.norm_moe, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
if moe_cls is not None:
self.moe = moe_cls(config)
else:
self.moe = None
def forward(
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, attention_mask=None
):
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
hidden_states_moe = self.norm_moe(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
hidden_states, residual = fused_add_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states_moe, _ = fused_add_norm_fn(
hidden_states,
self.norm_moe.weight,
self.norm_moe.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm_moe.eps,
)
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
hidden_states_moe = self.moe(hidden_states_moe)
hidden_states += hidden_states_moe
return hidden_states , residual
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
class MoEBlock(nn.Module):
def __init__(
self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, attention_mask=None, layer_idx=None
):
super().__init__()
self.config = config
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.mixer = mixer_cls(config, layer_idx=layer_idx)
if config.use_module_layernorm and not config.rms_norm:
self.norm = norm_cls
else:
self.norm = norm_cls(config.hidden_size)
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
if moe_cls is not None:
self.moe = moe_cls(config)
else:
self.moe = None
def forward(
self, hidden_states: Tensor, residual: Optional[Tensor] = None, from_tf: Optional[Tensor] = None, inference_params=None, attention_mask=None, forward_layer_idx = None
):
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
hidden_states, residual = fused_add_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states = self.mixer(hidden_states, forward_layer_idx = forward_layer_idx)
return hidden_states , residual
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
class vBlock(nn.Module):
def __init__(
self, config, sa_cls, mlp_cls=None, norm_cls=nn.LayerNorm, residual_in_fp32=False, layer_idx=None
):
super().__init__()
self.use_mem_mlp = config.use_mem_mlp
self.sa = Memory_AttentionBlock(config, mixer_cls=sa_cls, norm_cls=norm_cls, residual_in_fp32=config.residual_in_fp32)
if config.use_mem_mlp:
self.mlp = MoEBlock(config, mixer_cls=mlp_cls, norm_cls=norm_cls, residual_in_fp32=config.residual_in_fp32, layer_idx=-1)
def forward(self, hidden_states, residual=None, x_orig=None, inference_params=None, attention_mask=None, rotary_pos_emb=None, forward_layer_idx = None):
x = hidden_states + residual if residual is not None else hidden_states
x_ = torch.concatenate([x, x_orig], dim=-1).type(hidden_states.dtype)
x = self.sa(x_, inference_params=inference_params, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb, forward_layer_idx = forward_layer_idx)
if self.use_mem_mlp:
x, residual = self.mlp(x, forward_layer_idx = forward_layer_idx)
return x
def count_mem_blocks_in_config(config):
num_gs = 0
for val in config.layer_mapping:
if val == 'g':
num_gs +=1
return num_gs
def create_block(config, layer_idx):
factory_kwargs = {}
if layer_idx == -1:
num_gs = count_mem_blocks_in_config(config)
norm_cls = partial(RMSNorm, eps=config.layernorm_epsilon, dtype=torch.float32)
sa_cls = partial(CausalSelfAttention, **factory_kwargs, layer_number=-1, num_gs=num_gs)
mlp_cls = partial(MLP, layer_idx=layer_idx, **factory_kwargs, num_gs = num_gs)
block = vBlock(
config,
sa_cls=sa_cls,
mlp_cls=mlp_cls,
norm_cls=norm_cls,
residual_in_fp32=config.residual_in_fp32,
layer_idx=layer_idx
)
else:
norm_cls = partial(nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon)
if (not config.layer_mapping) or config.layer_mapping[layer_idx-1][0] == 'r' or config.layer_mapping[layer_idx-1][0] == 'g':
if (not config.layer_mapping) or len(config.layer_mapping[layer_idx-1]) == 1:
if 'm' in config.layer_mapping:
mixer_cls = partial(Mamba2Layer, layer_idx=layer_idx, **factory_kwargs)
else:
mixer_cls = partial(MambaLayer, layer_idx=layer_idx, **factory_kwargs)
block = MambaBlock(
config,
mixer_cls=mixer_cls,
norm_cls=norm_cls,
fused_add_norm=config.fused_add_norm,
residual_in_fp32=config.residual_in_fp32,
)
elif config.layer_mapping[layer_idx-1][0] == 'm':
mixer_cls = partial(Mamba2Layer, layer_idx=layer_idx, **factory_kwargs)
block = MambaBlock(
config,
mixer_cls=mixer_cls,
norm_cls=norm_cls,
fused_add_norm=config.fused_add_norm,
residual_in_fp32=config.residual_in_fp32,
)
return block
class MambaDecoder(nn.Module):
def __init__(
self,
config: MambaConfig,
post_layer_norm=True,
pre_process=True,
post_process=True,
):
super().__init__()
self.config: MambaConfig = config
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.checkpoint_core_block = self.config.recompute_granularity == 'selective'
self.num_layers_per_pipeline_rank = (
self.config.num_layers
)
self.layer_mapping = config.layer_mapping
self.use_mem_rope = config.use_mem_rope
self._build_layers()
def _build_layers(self):
num_layers_to_build = self.num_layers_per_pipeline_rank
self.layers = torch.nn.ModuleList([create_block(self.config, i + 1) for i in range(num_layers_to_build)])
if self.config.num_mem_heads > 0:
blocks = []
for _ in range(self.config.num_mem_blocks):
blocks.append(create_block(self.config, layer_idx=-1))
self.blocks = torch.nn.ModuleList(blocks)
self.block_map = torch.nn.ModuleList([
nn.Linear(self.config.hidden_size, self.config.hidden_size, bias = self.config.add_bias_linear) if (i%2 == 1 if (self.layer_mapping is None) else self.layer_mapping[i] == 'g') else nn.Identity() for i in range(self.config.num_layers)])
if self.use_mem_rope:
self.rotary_pos_emb = RotaryEmbedding(
2 * self.config.hidden_size // self.config.num_mem_heads, rotary_percent=1.0, seq_len_interpolation_factor=None
)
if self.config.use_low_rank_mamba_proj:
blocks = []
d_inner = self.config.expansion_factor * self.config.hidden_size
nheads = d_inner // self.config.mamba_headdim
d_in_proj = 2 * d_inner + 2 * self.config.mamba_ngroups * self.config.state_size + nheads
for _ in range(self.config.num_shared_mamba_proj):
blocks.append(nn.Linear(self.config.hidden_size, d_in_proj, bias = self.config.add_bias_linear))
self.in_projs = torch.nn.ModuleList(blocks)
if self.post_process and self.post_layer_norm:
self.final_layernorm = RMSNorm(self.config.hidden_size, eps=self.config.layernorm_epsilon, dtype=torch.float32)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, residual, inference_params):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*args, **kwargs):
x, residual, *args = args
for index in range(start, end):
layer = self._get_layer(index)
x, residual = layer(x, residual, *args, **kwargs)
return x, residual
return custom_forward
if self.config.recompute_method == 'uniform':
l = 0
while l < self.num_layers_per_pipeline_rank:
hidden_states,residual = tensor_parallel.checkpoint(
custom(l, l + self.config.recompute_num_layers),
self.config.distribute_saved_activations,
hidden_states,
residual,
inference_params,
)
l += self.config.recompute_num_layers
elif self.config.recompute_method == 'block':
for l in range(self.num_layers_per_pipeline_rank):
if l < self.config.recompute_num_layers:
hidden_states, residual = tensor_parallel.checkpoint(
custom(l, l + 1),
self.config.distribute_saved_activations,
hidden_states,
residual,
inference_params,
)
else:
hidden_states = custom(l, l + 1)(hidden_states, residual, inference_params)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states, residual
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, residual = None, inference_params=None, attention_mask=None):
if not self.pre_process:
hidden_states = self.input_tensor
rng_context = nullcontext()
fp8_context = nullcontext()
with rng_context and fp8_context:
residual = None
x_orig = torch.clone(hidden_states)
from_tf = None
block_count = 0
rotary_pos_emb=None
if self.use_mem_rope:
if inference_params is not None and inference_params.sequence_len_offset > 0:
rotary_seq_len = inference_params.max_sequence_length
else:
rotary_seq_len = hidden_states.shape[1]
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
for i, layer in enumerate(self.layers):
if self.config.num_mem_heads > 0:
if (i%2 == 1 if (self.layer_mapping is None) else self.layer_mapping[i] == 'g'):
from_tf = self.block_map[i](
self.blocks[block_count % self.config.num_mem_blocks](
hidden_states, residual, x_orig, inference_params=inference_params, attention_mask = attention_mask, rotary_pos_emb=rotary_pos_emb, forward_layer_idx=block_count
)
)
block_count += 1
else:
from_tf, _ = (None, None)
from_shared_proj = None
if self.config.use_low_rank_mamba_proj:
from_shared_proj = self.in_projs[i % self.config.num_shared_mamba_proj](hidden_states)
hidden_states, residual = layer(
hidden_states=hidden_states,
from_shared_proj=from_shared_proj,
from_tf=from_tf,
residual = residual,
inference_params=inference_params,
attention_mask = attention_mask,
)
if self.post_process and self.post_layer_norm:
if not self.config.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.final_layernorm(residual.to(dtype=self.final_layernorm.weight.dtype))
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.final_layernorm, RMSNorm) else layer_norm_fn
hidden_states = fused_add_norm_fn(
hidden_states,
self.final_layernorm.weight,
self.final_layernorm.bias,
eps=self.final_layernorm.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
return hidden_states