Showing 4 changed files with 215 additions and 0 deletions.
@@ -0,0 +1 @@
from nanotron.mod.mod import MixtureOfDepth, Router
@@ -0,0 +1,146 @@
from typing import Dict, Optional, Union, List

import torch
from torch import nn

import torch.distributed as dist
from nanotron.config import LlamaConfig, ParallelismArgs
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.tensor_parallel.nn import (
    TensorParallelColumnLinear,
    TensorParallelLinearMode,
)
from nanotron.models.llama import LlamaModel, Embedding, LlamaDecoderLayer, CausalSelfAttention, MLP
from nanotron.mod.mod import MixtureOfDepth, Router


# class LlamaDecoderLayer(nn.Module):
#     def __init__(
#         self,
#         config: LlamaConfig,
#         parallel_config: Optional[ParallelismArgs],
#         tp_pg: dist.ProcessGroup,
#         layer_idx: int,
#     ):
#         super().__init__()
#         self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
#         self.attn = CausalSelfAttention(
#             config=config,
#             parallel_config=parallel_config,
#             tp_pg=tp_pg,
#             layer_idx=layer_idx,
#         )
#
#         self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
#         self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
#         self.router = Router(seq_len=1024, top_k=10)
#
#     def forward(
#         self,
#         hidden_states: Union[torch.Tensor, TensorPointer],
#         sequence_mask: Union[torch.Tensor, TensorPointer],
#     ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
#         residual = hidden_states
#         hidden_states = self.input_layernorm(hidden_states)
#
#         output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
#         hidden_states = output["hidden_states"]
#         hidden_states = hidden_states + residual
#
#         residual = hidden_states
#         hidden_states = self.post_attention_layernorm(hidden_states)
#         hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
#         hidden_states = hidden_states + residual
#
#         return {
#             "hidden_states": hidden_states,
#             "sequence_mask": output["sequence_mask"],
#         }


class MoDLlamaModel(nn.Module, LlamaModel):
    """Build pipeline graph"""

    def __init__(
        self,
        config: LlamaConfig,
        parallel_context: ParallelContext,
        parallel_config: Optional[ParallelismArgs],
    ):
        super().__init__()

        # Declare all the nodes
        self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
        self.config = config
        self.parallel_config = parallel_config
        self.parallel_context = parallel_context
        self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
        tp_linear_async_communication = (
            parallel_config.tp_linear_async_communication if parallel_config is not None else False
        )

        self.token_position_embeddings = PipelineBlock(
            p2p=self.p2p,
            module_builder=Embedding,
            module_kwargs={
                "tp_pg": parallel_context.tp_pg,
                "config": config,
                "parallel_config": parallel_config,
            },
            module_input_keys={"input_ids", "input_mask"},
            module_output_keys={"input_embeds"},
        )

        self.decoder = nn.ModuleList(
            [
                PipelineBlock(
                    p2p=self.p2p,
                    module_builder=LlamaDecoderLayer,
                    module_kwargs={
                        "config": config,
                        "parallel_config": parallel_config,
                        "tp_pg": parallel_context.tp_pg,
                        "layer_idx": layer_idx,
                    },
                    module_input_keys={"hidden_states", "sequence_mask"},
                    module_output_keys={"hidden_states", "sequence_mask"},
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

        self.final_layer_norm = PipelineBlock(
            p2p=self.p2p,
            module_builder=TritonRMSNorm,
            module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
            module_input_keys={"input"},
            module_output_keys={"hidden_states"},
        )

        self.lm_head = PipelineBlock(
            p2p=self.p2p,
            # Understand that this means that we return sharded logits that are going to need to be gathered
            module_builder=TensorParallelColumnLinear,
            module_kwargs={
                "in_features": config.hidden_size,
                "out_features": config.vocab_size,
                "pg": parallel_context.tp_pg,
                "bias": False,
                # TODO @thomasw21: refactor so that we store that default in a single place.
                "mode": self.tp_mode,
                "async_communication": tp_linear_async_communication,
            },
            module_input_keys={"x"},
            module_output_keys={"logits"},
        )

        self.cast_to_fp32 = PipelineBlock(
            p2p=self.p2p,
            module_builder=lambda: lambda x: x.float(),
            module_kwargs={},
            module_input_keys={"x"},
            module_output_keys={"output"},
        )
@@ -0,0 +1,32 @@
import torch
from torch import nn
from torchtyping import TensorType
import torch.nn.functional as F


class MixtureOfDepth(nn.Module):
    def __init__(self, capacity: int, d_model: int, block: nn.Module):
        super().__init__()
        self.router = Router(capacity, d_model)
        self.block = block

    def forward(self, inputs: TensorType["batch_size", "seq_len", "d_model"]) -> TensorType["batch_size", "seq_len", "d_model"]:
        selected_idxs = self.router(inputs)
        assert selected_idxs.shape == (inputs.size(0), self.router.capacity)
        selected_inputs = inputs[torch.arange(inputs.size(0)).unsqueeze(1), selected_idxs]

        outputs_of_selected_inputs = self.block(selected_inputs)
        # NOTE: keep the new representation of the selected inputs and write it back over the original inputs
        inputs[torch.arange(inputs.size(0)).unsqueeze(1), selected_idxs] = outputs_of_selected_inputs
        return inputs


class Router(nn.Module):
    def __init__(self, capacity: int, d_model: int):
        super().__init__()
        self.capacity = capacity
        self.gate = nn.Linear(d_model, 1)

    def forward(self, inputs: TensorType["batch_size", "seq_len", "d_model"]) -> TensorType["batch_size", "capacity"]:
        probs = F.softmax(self.gate(inputs), dim=1).view(-1, inputs.size(1))
        _, top_k_indices = torch.topk(probs, self.capacity)
        return top_k_indices
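A minimal forward-only usage sketch of the two classes above; the sizes and the wrapped nn.Linear are arbitrary choices for illustration, not part of the commit. Router scores each token with a single linear gate, softmaxes over the sequence dimension, and keeps the `capacity` highest-scoring positions per sequence; MixtureOfDepth runs only those positions through the wrapped block and writes the results back into the input tensor in place, so the full (batch, seq_len, d_model) shape is preserved.

# Minimal sketch; d_model, capacity, and the wrapped Linear are arbitrary illustrative choices.
import torch
from torch import nn

from nanotron.mod.mod import MixtureOfDepth

d_model, capacity = 64, 4
mod_block = MixtureOfDepth(capacity, d_model, nn.Linear(d_model, d_model))

x = torch.randn(2, 16, d_model)  # (batch, seq_len, d_model)
y = mod_block(x)                 # only `capacity` tokens per sequence pass through the Linear
assert y.shape == x.shape        # note: `x` itself is modified in place and returned as `y`

This in-place update is what the test below relies on when it counts how many token rows still match a clone of the original inputs.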
@@ -0,0 +1,36 @@
import torch
from torch import nn
import pytest

from nanotron.mod import MixtureOfDepth, Router


@pytest.mark.parametrize("seq_len, top_k", [(1, 1), (10, 5), (10, 10)])
def test_mod(seq_len, top_k):
    BATCH_SIZE = 15
    D_MODEL = 1024

    linear = nn.Linear(D_MODEL, D_MODEL)
    block = MixtureOfDepth(top_k, D_MODEL, linear)

    inputs = torch.randn(BATCH_SIZE, seq_len, D_MODEL)
    ref_inputs = inputs.clone()
    outputs = block(inputs)

    expected_num_tokens_not_changed = (seq_len - top_k) * BATCH_SIZE
    num_tokens_not_changed = torch.eq(outputs.view(-1, D_MODEL), ref_inputs.view(-1, D_MODEL)).all(dim=1).sum().item()

    assert outputs.shape == linear(ref_inputs).shape
    assert num_tokens_not_changed == expected_num_tokens_not_changed, f"num_tokens_not_changed: {num_tokens_not_changed}, expected: {expected_num_tokens_not_changed}"


@pytest.mark.parametrize("capacity, d_model", [(1, 64), (10, 64)])
def test_router(capacity, d_model):
    BATCH_SIZE, SEQ_LEN = 5, 10
    inputs = torch.randn(BATCH_SIZE, SEQ_LEN, d_model)

    router = Router(capacity, d_model)
    selected_idxs = router(inputs)

    assert selected_idxs.shape == (BATCH_SIZE, capacity)
    assert selected_idxs.dtype == torch.int64