Merge branch 'main' into mamba-generate
3outeille committed Apr 24, 2024
2 parents f6befa6 + b792a22 commit cb17862
Showing 39 changed files with 2,163 additions and 260 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -162,4 +162,4 @@ cython_debug/
.vscode

checkpoints/
wandb/*
wandb/
5 changes: 5 additions & 0 deletions README.md
@@ -44,6 +44,7 @@ We support the following:
- ZeRO-1 optimizer
- FP32 gradient accumulation
- Parameter tying/sharding
- Spectral µTransfer parametrization for scaling up neural networks

# Installation

@@ -111,6 +112,10 @@ Features we would like to add:
- `scripts/log_lighteval_to_wandb.py`: logs the evaluation results of LightEval to wandb, including summary statistics.


# Environment Variables
- `NANOTRON_BENCHMARK=1`: set this to log throughput during training
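
A minimal sketch of setting the flag from a small Python launcher, assuming the variable only needs to be present in the training process's environment (the `torchrun` command line below mirrors the usual nanotron launch and is illustrative):

```python
import os
import subprocess

# Export NANOTRON_BENCHMARK=1 so the run logs throughput during training.
env = dict(os.environ, NANOTRON_BENCHMARK="1")
subprocess.run(
    [
        "torchrun", "--nproc_per_node=2",
        "run_train.py", "--config-file", "examples/config_tiny_llama.yaml",
    ],
    env=env,
    check=True,
)
```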


# Credits

We would like to thank everyone working on LLMs, especially those who share their work openly and from whom we took great inspiration: Nvidia for `Megatron-LM/apex`, Microsoft for `DeepSpeed`, HazyResearch for `flash-attn`.
11 changes: 7 additions & 4 deletions examples/config_tiny_llama.py
@@ -2,6 +2,7 @@
import os

from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
Config,
DataArgs,
@@ -62,11 +63,13 @@
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=True,
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
learning_rate_scheduler=learning_rate,
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)

parallelism = ParallelismArgs(
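For context on the change above: the Adam hyperparameters move off `OptimizerArgs` and onto a nested factory object. A sketch of the resulting construction, assuming a scheduler object like the `LRSchedulerArgs` used elsewhere in the examples (values and any fields not visible in the diff are illustrative):

```python
from nanotron.config import AdamWOptimizerArgs, LRSchedulerArgs, OptimizerArgs

# Illustrative values; only the nesting of AdamWOptimizerArgs is the point here.
learning_rate = LRSchedulerArgs(
    learning_rate=3e-4,
    lr_warmup_steps=2,
    lr_warmup_style="linear",
    lr_decay_style="cosine",
    min_decay_lr=1e-5,
)

optimizer = OptimizerArgs(
    zero_stage=0,
    weight_decay=0.01,
    clip_grad=1.0,
    accumulate_grad_in_fp32=True,
    learning_rate_scheduler=learning_rate,
    # Adam-specific settings now live on the factory, not on OptimizerArgs itself.
    optimizer_factory=AdamWOptimizerArgs(
        adam_eps=1e-08,
        adam_beta1=0.9,
        adam_beta2=0.95,
        torch_adam_is_fused=True,
    ),
)
```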
72 changes: 53 additions & 19 deletions examples/config_tiny_llama.yaml
@@ -1,6 +1,6 @@
checkpoints:
checkpoint_interval: 10
checkpoints_path: /fsx/nouamane/projects/nanotron/checkpoints
checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/nanotron/checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
@@ -37,28 +37,24 @@ general:
run: tiny_llama_%date_%jobid
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.025
# use_mup: true # uncomment this and comment the std line above to use spectral µTransfer
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 16
hidden_size: 32
initializer_range: 0.02
intermediate_size: 64
intermediate_size: 128
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 4
num_hidden_layers: 2
num_hidden_layers: 10
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
@@ -69,30 +65,57 @@ model:
vocab_size: 256
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
learning_rate: 0.001
lr_decay_starting_step: null
lr_decay_steps: 8
lr_decay_steps: null
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_steps: 2000 # 20% of the total steps
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
expert_parallel_size: 1
pp: 2
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
data_stages:
- name: Stable Training Stage
start_training_step: 1
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
- name: Annealing Phase
start_training_step: 10
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_codealpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
lighteval: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
@@ -103,5 +126,16 @@ tokens:
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 32
train_steps: 10
train_steps: 15
val_check_interval: -1
checkpoints:
checkpoint_interval: 10
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: checkpoints
save_initial_state: false
profiler: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
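As a rough illustration of the schedule configured above — linear warmup to `learning_rate`, then cosine decay down to `min_decay_lr` — here is a generic sketch of that shape (not nanotron's actual scheduler code; `total_steps` is a placeholder consistent with the "20% of the total steps" comment):

```python
import math

def lr_at_step(step, total_steps=10_000, peak_lr=1e-3, warmup_steps=2_000, min_lr=1e-5):
    """Linear warmup followed by cosine decay; illustrative only."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps  # linear ramp from 0 to peak_lr
    # Cosine decay from peak_lr down to min_lr over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * min(progress, 1.0)))

print(lr_at_step(0))       # 0.0
print(lr_at_step(2_000))   # ≈ 0.001 (end of warmup)
print(lr_at_step(10_000))  # ≈ 1e-05 (fully decayed)
```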
8 changes: 7 additions & 1 deletion examples/contributor-guide/debug_config_tiny_llama.py
@@ -5,6 +5,7 @@
CheckpointsArgs,
Config,
DataArgs,
DatasetStageArgs,
GeneralArgs,
LlamaConfig,
LoggingArgs,
@@ -95,7 +96,12 @@
optimizer=optimizer,
logging=LoggingArgs(),
tokens=tokens,
data=DataArgs(dataset=dataset, seed=seed),
data_stages=[
DatasetStageArgs(
name="Stable Training Stage", start_training_step=1, data=DataArgs(dataset=dataset, seed=seed)
),
DatasetStageArgs(name="Annealing Phase", start_training_step=10, data=DataArgs(dataset=dataset, seed=seed)),
],
profiler=None,
)

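The stages above are keyed by `start_training_step`; the hypothetical helper below illustrates the intended semantics (the stage with the largest `start_training_step` not exceeding the current step is active). This is an illustration, not the trainer's actual stage-switching code:

```python
from dataclasses import dataclass

@dataclass
class Stage:  # stand-in for DatasetStageArgs, for illustration only
    name: str
    start_training_step: int

def active_stage(stages, current_step):
    """Return the stage whose start_training_step is the latest one <= current_step."""
    started = [s for s in stages if s.start_training_step <= current_step]
    return max(started, key=lambda s: s.start_training_step)

stages = [Stage("Stable Training Stage", 1), Stage("Annealing Phase", 10)]
print(active_stage(stages, current_step=5).name)   # Stable Training Stage
print(active_stage(stages, current_step=12).name)  # Annealing Phase
```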
39 changes: 25 additions & 14 deletions examples/contributor-guide/debug_config_tiny_llama.yaml
@@ -1,23 +1,34 @@
checkpoints:
checkpoint_interval: 10
checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/nanotron/examples/checkpoints
checkpoints_path: /fsx/haojun/nanotron_latest/examples/checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false

data_stages:
- name: General purpose training
start_training_step: 1
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
Expand Down
16 changes: 9 additions & 7 deletions examples/mamba/mamba.py
@@ -181,15 +181,13 @@ def __init__(
self.A_log = create_sharded_parameter_from_config(
parameter=A_log, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)
)
self.A_log._no_weight_decay = True

# D "skip" parameter
self.D = create_sharded_parameter_from_config(
parameter=torch.ones(self.d_inner // self.tp_pg.size(), device=device),
pg=self.tp_pg,
split_config=SplitConfig(split_dim=0),
)
self.D._no_weight_decay = True

# self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.out_proj = TensorParallelRowLinear(
@@ -806,6 +804,14 @@ def forward(
label_mask=label_mask,
)["loss"]
return {"loss": loss}

def get_named_params_without_weight_decay(self):
# get full name with "A_log", "D"
named_param_without_weight_decay = []
for name, _ in self.model.named_parameters():
if "A_log" in name or "D" in name:
named_param_without_weight_decay.append(name)
return named_param_without_weight_decay

@torch.no_grad()
def init_model_randomly(self, config):
@@ -916,11 +922,7 @@ def init_model_randomly(self, config):
raise ValueError(f"Who the fuck is {param_name}?")

elif isinstance(module, Mamba):
# NOTE(fmom): nn.Parameter are initialized in Mamba __init__
# In Mamba, only those 3 parameters don't have weight decay.
if param_name in ["dt_bias", "A_log", "D"]:
param._no_weight_decay = True

pass
else:
raise Exception(f"Parameter {full_param_name} was not initialized")

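The removed per-parameter `_no_weight_decay` flags are replaced by the model-level `get_named_params_without_weight_decay()` hook added above. A sketch of how such a hook could feed AdamW parameter groups (illustrative; nanotron's own optimizer builder may differ, and `model` / `weight_decay` here are assumed names):

```python
import torch

def build_param_groups(model, weight_decay=0.01):
    # Parameters named by the hook (e.g. Mamba's A_log and D) get no weight decay.
    no_decay_names = set(model.get_named_params_without_weight_decay())
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        (no_decay if name in no_decay_names else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-3)
```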
11 changes: 7 additions & 4 deletions examples/moe/config_llamoe.py
@@ -4,6 +4,7 @@
from typing import Optional

from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
Config,
DataArgs,
@@ -99,11 +100,13 @@ def __post_init__(self):
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=False,
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
learning_rate_scheduler=learning_rate,
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)

parallelism = ParallelismArgs(
Expand Down
46 changes: 30 additions & 16 deletions examples/moe/config_llamoe.yaml
@@ -5,18 +5,30 @@ checkpoints:
resume_checkpoint_path: /fsx/nouamane/projects/nanotron/examples/checkpoints
save_initial_state: true
data_stages:
- name: General purpose training
start_training_step: 1
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 12
hf_dataset_config_name: null
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 12
hf_dataset_config_name: null
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 12
hf_dataset_config_name: null
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
@@ -60,9 +72,6 @@ model:
vocab_size: 32000
optimizer:
accumulate_grad_in_fp32: false
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
@@ -72,7 +81,12 @@ optimizer:
lr_warmup_steps: 100
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
