diff --git a/.gitignore b/.gitignore index 29c2de5d..cbc04eaf 100644 --- a/.gitignore +++ b/.gitignore @@ -162,4 +162,4 @@ cython_debug/ .vscode checkpoints/ -wandb/* \ No newline at end of file +wandb/ diff --git a/README.md b/README.md index f910a994..b5748d60 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ We support the following: - ZeRO-1 optimizer - FP32 gradient accumulation - Parameter tying/sharding + - Spectral µTransfer parametrization for scaling up neural networks # Installation @@ -111,6 +112,10 @@ Features we would like to add: - `scripts/log_lighteval_to_wandb.py`: logs the evaluation results of LightEval to wandb, including summary statistics. +# Environment Variables +- `NANOTRON_BENCHMARK=1`: if you want to log the throughput during training + + # Credits We would like to thank everyone working on LLMs, especially those sharing their work openly from which we took great inspiration: Nvidia for `Megatron-LM/apex`, Microsoft for `DeepSpeed`, HazyResearch for `flash-attn` diff --git a/examples/config_tiny_llama.py b/examples/config_tiny_llama.py index 31431956..dfbee136 100644 --- a/examples/config_tiny_llama.py +++ b/examples/config_tiny_llama.py @@ -2,6 +2,7 @@ import os from nanotron.config import ( + AdamWOptimizerArgs, CheckpointsArgs, Config, DataArgs, @@ -62,11 +63,13 @@ weight_decay=0.01, clip_grad=1.0, accumulate_grad_in_fp32=True, - adam_eps=1e-08, - adam_beta1=0.9, - adam_beta2=0.95, - torch_adam_is_fused=True, learning_rate_scheduler=learning_rate, + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + ), ) parallelism = ParallelismArgs( diff --git a/examples/config_tiny_llama.yaml b/examples/config_tiny_llama.yaml index d3ada238..0e87c663 100644 --- a/examples/config_tiny_llama.yaml +++ b/examples/config_tiny_llama.yaml @@ -1,6 +1,6 @@ checkpoints: checkpoint_interval: 10 - checkpoints_path: /fsx/nouamane/projects/nanotron/checkpoints + checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/nanotron/checkpoints checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null save_initial_state: false @@ -37,28 +37,24 @@ general: run: tiny_llama_%date_%jobid seed: 42 step: null -lighteval: null -logging: - iteration_step_info_interval: 1 - log_level: info - log_level_replica: info model: ddp_bucket_cap_mb: 25 dtype: bfloat16 init_method: std: 0.025 + # use_mup: true # uncomment this and comment the std line above to use spectral µTransfer make_vocab_size_divisible_by: 1 model_config: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 16 + hidden_size: 32 initializer_range: 0.02 - intermediate_size: 64 + intermediate_size: 128 is_llama_config: true max_position_embeddings: 256 num_attention_heads: 4 - num_hidden_layers: 2 + num_hidden_layers: 10 num_key_value_heads: 4 pad_token_id: null pretraining_tp: 1 @@ -69,30 +65,57 @@ model: vocab_size: 256 optimizer: accumulate_grad_in_fp32: true - adam_beta1: 0.9 - adam_beta2: 0.95 - adam_eps: 1.0e-08 clip_grad: 1.0 learning_rate_scheduler: - learning_rate: 0.0003 + learning_rate: 0.001 lr_decay_starting_step: null - lr_decay_steps: 8 + lr_decay_steps: null lr_decay_style: cosine - lr_warmup_steps: 2 + lr_warmup_steps: 2000 # 20% of the total steps lr_warmup_style: linear min_decay_lr: 1.0e-05 - torch_adam_is_fused: true + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true weight_decay: 0.01 zero_stage: 0 parallelism: dp: 2 expert_parallel_size: 1 - pp: 2 + pp: 
1 pp_engine: 1f1b tp: 2 tp_linear_async_communication: true tp_mode: REDUCE_SCATTER -profiler: null +data_stages: + - name: Stable Training Stage + start_training_step: 1 + data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small + hf_dataset_splits: train + text_column_name: completion + num_loading_workers: 1 + seed: 42 + - name: Annealing Phase + start_training_step: 10 + data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: HuggingFaceH4/testing_codealpaca_small + hf_dataset_splits: train + text_column_name: completion + num_loading_workers: 1 + seed: 42 +lighteval: null tokenizer: tokenizer_max_length: null tokenizer_name_or_path: gpt2 @@ -103,5 +126,16 @@ tokens: limit_val_batches: 0 micro_batch_size: 2 sequence_length: 32 - train_steps: 10 + train_steps: 15 val_check_interval: -1 +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: checkpoints + save_initial_state: false +profiler: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info diff --git a/examples/contributor-guide/debug_config_tiny_llama.py b/examples/contributor-guide/debug_config_tiny_llama.py index e1e2d065..096995b0 100644 --- a/examples/contributor-guide/debug_config_tiny_llama.py +++ b/examples/contributor-guide/debug_config_tiny_llama.py @@ -5,6 +5,7 @@ CheckpointsArgs, Config, DataArgs, + DatasetStageArgs, GeneralArgs, LlamaConfig, LoggingArgs, @@ -95,7 +96,12 @@ optimizer=optimizer, logging=LoggingArgs(), tokens=tokens, - data=DataArgs(dataset=dataset, seed=seed), + data_stages=[ + DatasetStageArgs( + name="Stable Training Stage", start_training_step=1, data=DataArgs(dataset=dataset, seed=seed) + ), + DatasetStageArgs(name="Annealing Phase", start_training_step=10, data=DataArgs(dataset=dataset, seed=seed)), + ], profiler=None, ) diff --git a/examples/contributor-guide/debug_config_tiny_llama.yaml b/examples/contributor-guide/debug_config_tiny_llama.yaml index 27c24ed0..096a49b7 100644 --- a/examples/contributor-guide/debug_config_tiny_llama.yaml +++ b/examples/contributor-guide/debug_config_tiny_llama.yaml @@ -1,23 +1,34 @@ checkpoints: checkpoint_interval: 10 - checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/nanotron/examples/checkpoints + checkpoints_path: /fsx/haojun/nanotron_latest/examples/checkpoints checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null save_initial_state: false - data_stages: - - name: General purpose training - start_training_step: 1 - data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small - hf_dataset_splits: train - text_column_name: completion - num_loading_workers: 1 - seed: 42 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small + hf_dataset_splits: train + text_column_name: completion + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: 
HuggingFaceH4/testing_alpaca_small + hf_dataset_splits: train + text_column_name: completion + num_loading_workers: 1 + seed: 42 + name: Annealing Phase + start_training_step: 10 general: benchmark_csv_path: null consumed_train_samples: null diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 20f86804..88ad85d2 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -181,7 +181,6 @@ def __init__( self.A_log = create_sharded_parameter_from_config( parameter=A_log, pg=self.tp_pg, split_config=SplitConfig(split_dim=0) ) - self.A_log._no_weight_decay = True # D "skip" parameter self.D = create_sharded_parameter_from_config( @@ -189,7 +188,6 @@ def __init__( pg=self.tp_pg, split_config=SplitConfig(split_dim=0), ) - self.D._no_weight_decay = True # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) self.out_proj = TensorParallelRowLinear( @@ -806,6 +804,14 @@ def forward( label_mask=label_mask, )["loss"] return {"loss": loss} + + def get_named_params_without_weight_decay(self): + # get full name with "A_log", "D" + named_param_without_weight_decay = [] + for name, _ in self.model.named_parameters(): + if "A_log" in name or "D" in name: + named_param_without_weight_decay.append(name) + return named_param_without_weight_decay @torch.no_grad() def init_model_randomly(self, config): @@ -916,11 +922,7 @@ def init_model_randomly(self, config): raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, Mamba): - # NOTE(fmom): nn.Parameter are initialized in Mamba __init__ - # In Mamba, only those 3 parameters don't have weight decay. - if param_name in ["dt_bias", "A_log", "D"]: - param._no_weight_decay = True - + pass else: raise Exception(f"Parameter {full_param_name} was not initialized") diff --git a/examples/moe/config_llamoe.py b/examples/moe/config_llamoe.py index ad1deec2..c1f314ea 100644 --- a/examples/moe/config_llamoe.py +++ b/examples/moe/config_llamoe.py @@ -4,6 +4,7 @@ from typing import Optional from nanotron.config import ( + AdamWOptimizerArgs, CheckpointsArgs, Config, DataArgs, @@ -99,11 +100,13 @@ def __post_init__(self): weight_decay=0.01, clip_grad=1.0, accumulate_grad_in_fp32=False, - adam_eps=1e-08, - adam_beta1=0.9, - adam_beta2=0.95, - torch_adam_is_fused=True, learning_rate_scheduler=learning_rate, + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + ), ) parallelism = ParallelismArgs( diff --git a/examples/moe/config_llamoe.yaml b/examples/moe/config_llamoe.yaml index 1b312129..46dc0534 100644 --- a/examples/moe/config_llamoe.yaml +++ b/examples/moe/config_llamoe.yaml @@ -5,18 +5,30 @@ checkpoints: resume_checkpoint_path: /fsx/nouamane/projects/nanotron/examples/checkpoints save_initial_state: true data_stages: - - name: General purpose training - start_training_step: 1 - data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 12 - hf_dataset_config_name: null - hf_dataset_or_datasets: roneneldan/TinyStories - hf_dataset_splits: train - text_column_name: text - num_loading_workers: 1 - seed: 42 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 12 + hf_dataset_config_name: null + hf_dataset_or_datasets: roneneldan/TinyStories + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + 
dataset_processing_num_proc_per_process: 12 + hf_dataset_config_name: null + hf_dataset_or_datasets: roneneldan/TinyStories + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Annealing Phase + start_training_step: 10 general: benchmark_csv_path: null consumed_train_samples: null @@ -60,9 +72,6 @@ model: vocab_size: 32000 optimizer: accumulate_grad_in_fp32: false - adam_beta1: 0.9 - adam_beta2: 0.95 - adam_eps: 1.0e-08 clip_grad: 1.0 learning_rate_scheduler: learning_rate: 0.0003 @@ -72,7 +81,12 @@ optimizer: lr_warmup_steps: 100 lr_warmup_style: linear min_decay_lr: 1.0e-05 - torch_adam_is_fused: true + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true weight_decay: 0.01 zero_stage: 0 parallelism: diff --git a/examples/moe/llamoe.py b/examples/moe/llamoe.py index fb274ad4..2c1a8d91 100644 --- a/examples/moe/llamoe.py +++ b/examples/moe/llamoe.py @@ -14,7 +14,7 @@ # limitations under the License. """ PyTorch LLaMa MoE model.""" import math -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, List import torch from config_llamoe import LlaMoEConfig @@ -915,7 +915,7 @@ def init_model_randomly(self, config): else name for name, param in model.named_parameters() }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" - + def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" return self.model.get_block_compute_costs() diff --git a/examples/moe/requirements.txt b/examples/moe/requirements.txt index 20b2778d..b32c55b4 100644 --- a/examples/moe/requirements.txt +++ b/examples/moe/requirements.txt @@ -1 +1,2 @@ +stanford-stk>=0.0.6 megablocks==0.5.1 diff --git a/examples/mup/README.md b/examples/mup/README.md new file mode 100644 index 00000000..c86850ca --- /dev/null +++ b/examples/mup/README.md @@ -0,0 +1,34 @@ +OpenAI's scaling laws [[link]](https://arxiv.org/abs/2001.08361) from 2020 showed that scaling is one of the core ingredients for the success of LLMs, but naively stacking more layers can lead to unstable training due to exploding or vanishing gradients. In our implementation, the experimental results show that on a 350M LLaMA, spectral µTransfer matches the pretraining performance of the baseline (albeit with a training loss that is 0.04 higher). In another MLP-only experiment, µTransfer maintains a consistent L1 norm of activations across widths and depths, and allows scaling up to 2B parameters, while the SP baseline blows up and becomes untrainable. + + +# How to use Spectral µTransfer +In your Nanotron configuration, simply set `use_mup` to `true`. Nanotron will automatically determine the right standard deviation and learning rate for each parameter. + + +```diff +model: + ... + init_method: +- std: 0.025 ++ use_mup: true +``` + +# MLP Only Experiment + +We ran a systematic experiment varying the number of layers from 8 to 32, the width from 128 to 8192, and the batch size from 32 to 2048, all on a logarithmic scale, on the CIFAR dataset, using an MSE training objective for 4 epochs with the Adam optimizer. 
[[Experiment Report]](https://wandb.ai/neuralink/exp14_mup_grid_search/reports/-Spectral-Transfer-MLP-s-Experiment-Results--Vmlldzo3NDQ0NTQw?accessToken=xe0mkunx3y8t0xzbzxu9caqcre57or5la58d9o209hinanlmzoaj7es24m4elvdj) + + +![Scale across widths](./assets/scale-across-width.png) + + + +![Scale across depths](./assets/scale-across-depth.png) + + +# On a 350M LLaMA + +We trained a 350M model with spectral µTransfer and with standard parametrization using Nanotron, with a global batch size of 1M tokens and a learning rate of 0.001. µTransfer matches the performance of standard parametrization, with a training loss that is 0.04 higher. [[Experiment Report]](https://api.wandb.ai/links/neuralink/i70nnpu9) + +Please check the directory [[./examples/mup/configs]](/examples/mup/configs) for the configurations we used to reproduce the experiments. + +![LLaMA](./assets/llama.png) diff --git a/examples/mup/assets/llama.png b/examples/mup/assets/llama.png new file mode 100644 index 00000000..c2ce0d9c Binary files /dev/null and b/examples/mup/assets/llama.png differ diff --git a/examples/mup/assets/scale-across-depth.png b/examples/mup/assets/scale-across-depth.png new file mode 100644 index 00000000..9c6e6276 Binary files /dev/null and b/examples/mup/assets/scale-across-depth.png differ diff --git a/examples/mup/assets/scale-across-width.png b/examples/mup/assets/scale-across-width.png new file mode 100644 index 00000000..457a2cbe Binary files /dev/null and b/examples/mup/assets/scale-across-width.png differ diff --git a/examples/mup/configs/mup_350m_llama_config.yaml b/examples/mup/configs/mup_350m_llama_config.yaml new file mode 100644 index 00000000..9a0402e8 --- /dev/null +++ b/examples/mup/configs/mup_350m_llama_config.yaml @@ -0,0 +1,141 @@ +checkpoints: + checkpoint_interval: 10000 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false + +data_stages: + - name: Stable Training Stage + start_training_step: 1 + data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: roneneldan/TinyStories + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + - name: Annealing Phase + start_training_step: 9000 + data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small + hf_dataset_splits: train + text_column_name: completion + num_loading_workers: 1 + seed: 42 + +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: llama_350m_mup + seed: 42 + step: null +logging: + iteration_step_info_interval: 1 + log_level: debug + log_level_replica: info +model: + ddp_bucket_cap_mb: 120 + dtype: bfloat16 + init_method: + use_mup: true + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + initializer_range: 0.02 + + hidden_size: 1024 + intermediate_size: 4096 + num_hidden_layers: 14 + + is_llama_config: true + max_position_embeddings: 1024 + num_attention_heads: 8 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: false + use_cache: true + vocab_size: 49152 +optimizer: + accumulate_grad_in_fp32: false + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + clip_grad: 1.0 + learning_rate_scheduler: 
+ learning_rate: 0.001 + lr_decay_starting_step: null + lr_decay_steps: null + lr_decay_style: cosine + lr_warmup_steps: 100 # 10% warm up of total training steps + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + torch_adam_is_fused: true + weight_decay: 0.1 + zero_stage: 0 +parallelism: + dp: 4 + pp: 1 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: gpt2 + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 8 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 32 + sequence_length: 1024 + train_steps: 440 + val_check_interval: -1 +lighteval: + batch_size: 16 + checkpoints_path: null + generation: null + logging: + hub_repo_details: null + hub_repo_results: null + # hub_repo_tensorboard: HuggingFaceBR4/fmom-mamba2 + local_output_path: /fsx/phuc/new_workspace/experiments/mup_for_mamba2/test_mamba350M_tp4_917cfc66/logs + push_details_to_hub: null + push_results_to_hub: null + push_results_to_tensorboard: true + tensorboard_metric_prefix: e + parallelism: + dp: 2 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE + # slurm_script_dir: /fsx/phuc/new_workspace/experiments/mup_for_mamba2/test_mamba350M_tp4_917cfc66/lighteval/slurm_scripts + # slurm_template: /fsx/phuc/new_workspace/experiments/mup_for_mamba2/test_mamba350M_tp4_917cfc66/run_eval.slurm.jinja + tasks: + # custom_tasks: brrr.lighteval.custom_tasks + dataset_loading_processes: 8 + max_samples: 1000 + multichoice_continuations_start_space: null + no_multichoice_continuations_start_space: null + num_fewshot_seeds: null + tasks: early-signal + wandb: null diff --git a/examples/mup/configs/sp_350m_llama_config.yaml b/examples/mup/configs/sp_350m_llama_config.yaml new file mode 100644 index 00000000..5bcdfdbb --- /dev/null +++ b/examples/mup/configs/sp_350m_llama_config.yaml @@ -0,0 +1,108 @@ +checkpoints: + checkpoint_interval: 10000 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false + +data_stages: + - name: Stable Training Stage + start_training_step: 1 + data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: roneneldan/TinyStories + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + - name: Annealing Phase + start_training_step: 9000 + data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small + hf_dataset_splits: train + text_column_name: completion + num_loading_workers: 1 + seed: 42 + +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: llama_350m_sp + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 120 + dtype: bfloat16 + init_method: + std: 0.03125 # 1/sqrt(1024)=0.022097086912079608 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + initializer_range: 0.02 + hidden_size: 1024 + intermediate_size: 4096 + num_hidden_layers: 14 + is_llama_config: true + max_position_embeddings: 1024 + num_attention_heads: 8 + num_key_value_heads: 4 + 
pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: false + use_cache: true + vocab_size: 49152 +optimizer: + accumulate_grad_in_fp32: false + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.001 + lr_decay_starting_step: null + lr_decay_steps: null + lr_decay_style: cosine + lr_warmup_steps: 100 # 10% warm up of total training steps + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + torch_adam_is_fused: true + weight_decay: 0.1 + zero_stage: 0 +parallelism: + dp: 4 + pp: 1 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: gpt2 + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 8 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 32 + sequence_length: 1024 + train_steps: 440 + val_check_interval: -1 diff --git a/run_train.py b/run_train.py index e5be3048..8dc16f7a 100644 --- a/run_train.py +++ b/run_train.py @@ -22,12 +22,14 @@ get_datasets, get_train_dataloader, ) +from nanotron.helpers import ( + compute_remain_train_steps_of_a_data_stage_from_ckp, + get_consumed_train_samples_of_a_data_stage_from_ckp, +) from nanotron.logging import log_rank from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks from nanotron.trainer import DistributedTrainer -from nanotron.utils import ( - main_rank_first, -) +from nanotron.utils import main_rank_first from torch.utils.data import DataLoader try: @@ -41,8 +43,21 @@ logger = logging.get_logger(__name__) -def get_dataloader_from_data_stage(trainer: DistributedTrainer, data: DataArgs): - """Returns a dataloader for training.""" +def get_dataloader_from_data_stage( + trainer: DistributedTrainer, + data: DataArgs, + consumed_train_samples: int, + num_remaining_train_steps: int, +): + """ + Returns a dataloader for a given data stage. + + data: The data configuration for the current stage. + consumed_train_samples: The number of samples consumed by the model in the this stage (each stage starts from zero). + num_remaining_train_steps: The number of remaining training steps for this stage. 
+ """ + assert consumed_train_samples >= 0, "consumed_train_samples should be greater than 0" + assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be greater than 0" # First, we need to know which ranks to feed the dataloader to input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) @@ -79,6 +94,7 @@ def get_dataloader_from_data_stage(trainer: DistributedTrainer, data: DataArgs): # We load the raw dataset raw_dataset = get_datasets( hf_dataset_or_datasets=data.dataset.hf_dataset_or_datasets, + hf_dataset_config_name=data.dataset.hf_dataset_config_name, splits=data.dataset.hf_dataset_splits, )["train"] @@ -104,17 +120,16 @@ def get_dataloader_from_data_stage(trainer: DistributedTrainer, data: DataArgs): input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, micro_batch_size=trainer.micro_batch_size, - consumed_train_samples=trainer.consumed_train_samples, + consumed_train_samples=consumed_train_samples, dataloader_num_workers=data.num_loading_workers, seed_worker=data.seed, dataloader_drop_last=True, ) + # Check if we have enough samples for train_steps total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length num_tokens_needed_for_training = ( - (trainer.config.tokens.train_steps - trainer.start_iteration_step) - * trainer.global_batch_size - * trainer.sequence_length + num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length ) assert num_tokens_needed_for_training <= total_tokens_dataset, ( f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), " @@ -127,16 +142,41 @@ def get_dataloader_from_data_stage(trainer: DistributedTrainer, data: DataArgs): def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: - sorted_stages = sorted(trainer.config.data_stages, key=lambda stage: stage.start_training_step) dataloaders = {} - for idx, stage in enumerate(sorted_stages): + + for stage_idx, stage in enumerate(trainer.config.data_stages): # NOTE: we only create the dataloader for the first stage, # then we lazy initialize the dataloader for the other stages stage = cast(DatasetStageArgs, stage) + consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata) + assert ( + consumed_train_samples is not None + ), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint" + + num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( + stage, trainer.config, trainer.metadata + ) + log_rank( + f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples", + logger=logger, + level=logging.INFO, + rank=0, + ) + dataloader = ( - get_dataloader_from_data_stage(trainer, stage.data) - if idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data) + get_dataloader_from_data_stage( + trainer, + stage.data, + consumed_train_samples=consumed_train_samples, + num_remaining_train_steps=num_remaining_train_steps, + ) + if stage_idx == 0 + else lambda stage=stage: get_dataloader_from_data_stage( + trainer, + stage.data, + consumed_train_samples=consumed_train_samples, + num_remaining_train_steps=num_remaining_train_steps, + ) ) dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index bf816ed1..d9946f26 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -11,11 +11,7 
@@ from yaml.loader import SafeLoader from nanotron.config.lighteval_config import LightEvalConfig -from nanotron.config.models_config import ( - ExistingCheckpointInit, - NanotronConfigs, - RandomInit, -) +from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit from nanotron.config.parallelism_config import ParallelismArgs from nanotron.config.utils_config import ( RecomputeGranularity, @@ -189,7 +185,7 @@ class ModelArgs: """Arguments related to model architecture""" model_config: NanotronConfigs - init_method: Union[RandomInit, ExistingCheckpointInit] + init_method: Union[RandomInit, SpectralMupInit, ExistingCheckpointInit] dtype: Optional[torch.dtype] = None make_vocab_size_divisible_by: int = 1 ddp_bucket_cap_mb: int = 25 @@ -200,6 +196,8 @@ def __post_init__(self): if isinstance(self.dtype, str): self.dtype = cast_str_to_torch_dtype(self.dtype) + self.model_config._is_using_mup = isinstance(self.init_method, SpectralMupInit) + # if self.model_config.max_position_embeddings is None: # self.model_config.max_position_embeddings = 0 @@ -264,20 +262,29 @@ def __post_init__(self): self.min_decay_lr = self.learning_rate +@dataclass +class SGDOptimizerArgs: + name: str = "sgd" + + +@dataclass +class AdamWOptimizerArgs: + adam_eps: float + adam_beta1: float + adam_beta2: float + torch_adam_is_fused: bool + name: str = "adamW" + + @dataclass class OptimizerArgs: """Arguments related to the optimizer and learning rate""" + optimizer_factory: Union[SGDOptimizerArgs, AdamWOptimizerArgs] zero_stage: int weight_decay: float clip_grad: Optional[float] - accumulate_grad_in_fp32: bool - - adam_eps: float - adam_beta1: float - adam_beta2: float - torch_adam_is_fused: bool learning_rate_scheduler: LRSchedulerArgs @@ -331,6 +338,7 @@ def __post_init__(self): ) if self.data_stages is not None: + self.data_stages = sorted(self.data_stages, key=lambda stage: stage.start_training_step) names = [stage.name for stage in self.data_stages] training_steps = [stage.start_training_step for stage in self.data_stages] assert any( @@ -346,6 +354,12 @@ def __post_init__(self): f"Each stage should have unique starting training step, please change the starting training step for stage {stage.name}" ) + # NOTE: must order the stages by start_training_step from lowest to highest + assert all( + self.data_stages[i].start_training_step < self.data_stages[i + 1].start_training_step + for i in range(len(self.data_stages) - 1) + ), "The stages are not sorted by start_training_step in increasing order" + # # if lighteval, we need tokenizer to be defined # if self.checkpoints.lighteval is not None: # assert self.tokenizer.tokenizer_name_or_path is not None diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 610acc06..ba4559cf 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -8,6 +8,16 @@ class RandomInit: std: float +@dataclass +class SpectralMupInit: + """This is used to initialize the model with spectral mup. 
Set it to True to use it.""" + + use_mup: bool + + def __post_init__(self): + assert self.use_mup, "Remove `use_mup` if you don't want to use it" + + @dataclass class ExistingCheckpointInit: """This is used to initialize from an already existing model (without optimizer, lr_scheduler...)""" @@ -42,10 +52,19 @@ class LlamaConfig: vocab_size: int = 32000 def __post_init__(self): + # NOTE: users don't set self._init_method, ModelArgs will set it + # then we only pass LlamaConfig around + self._is_using_mup: bool = False + # self._init_method: Optional[Union[RandomInit, SpectralMupInit, ExistingCheckpointInit]] = None + # for backward compatibility if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads + @property + def is_using_mup(self) -> bool: + return self._is_using_mup + @dataclass class Starcoder2Config: diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 84f6b55b..48332059 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -2,6 +2,11 @@ from packaging.version import Version, parse -CHECKPOINT_VERSION = Version("1.2") +CHECKPOINT_VERSION = Version("1.3") PY_VERSION = parse(platform.python_version()) + +#### FOR SERIALIZATION #### + +CHECKPOINT_FILE_NAME = "checkpoint_metadata.json" +MODEL_CONFIG_FILE_NAME = "model_config.json" diff --git a/src/nanotron/dataloader.py b/src/nanotron/dataloader.py index d8c7885a..b451ec66 100644 --- a/src/nanotron/dataloader.py +++ b/src/nanotron/dataloader.py @@ -85,6 +85,7 @@ def sanity_check_dataloader( # Adapted from h4/src/h4/data/loading.py def get_datasets( hf_dataset_or_datasets: Union[dict, str], + hf_dataset_config_name: Optional[str] = None, splits: Optional[Union[List[str], str]] = ["train", "test"], ) -> "DatasetDict": """ @@ -116,6 +117,9 @@ def get_datasets( for split in splits: raw_datasets[split] = load_dataset( hf_dataset_or_datasets, + # NOTE: we can't pass this as config_name=config_name, + # so we have to pass it as a positional argument 
+ hf_dataset_config_name, split=split, ) else: diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 116cf653..f7bf63e5 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -5,20 +5,20 @@ import os import time from datetime import datetime +from functools import partial from math import ceil -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import numpy as np import torch from torch import nn from torch.nn.parallel import DistributedDataParallel -from torch.optim import AdamW from torch.optim.lr_scheduler import LambdaLR from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler from nanotron import distributed as dist from nanotron import logging -from nanotron.config import Config, LRSchedulerArgs, OptimizerArgs, ParallelismArgs +from nanotron.config import Config, DatasetStageArgs, LRSchedulerArgs, OptimizerArgs, ParallelismArgs from nanotron.distributed import ProcessGroup from nanotron.logging import LogItem, log_rank from nanotron.models.base import NanotronModel @@ -41,6 +41,8 @@ get_current_random_state, get_synced_random_state, ) +from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod +from nanotron.serialize.metadata import TrainingMetadata logger = logging.get_logger(__name__) @@ -91,8 +93,17 @@ def lr_scheduler_builder(optimizer: Optimizer, lr_scheduler_args: LRSchedulerArg else: lr_decay_starting_step = lr_scheduler_args.lr_decay_starting_step - def lr_lambda(current_step: int): - """LR Scheduling function, it has from 2 up to 4 phases: + def lr_lambda(current_step: int, initial_lr: float): + """ + current_step: current training step + initial_lr: the learning rate of a parameter group + + More info on initial_lr: + And in standard parameterization, lr_lambda only takes a single learning rate. + But in µTransfer, each parameter has a custom learning rate (custom_lr = lr_scheduler_args.learning_rate * scaling_factor), + so each parameter group has a custom lr_lambda function. 
+ + LR Scheduling function, it has from 2 up to 4 phases: - warmup, - optional: constant (if lr_decay_starting_step is set) - decay @@ -104,12 +115,12 @@ def lr_lambda(current_step: int): """ # No warmup or decay if lr_scheduler_args.lr_warmup_steps == 0 and lr_decay_steps == 0: - return lr_scheduler_args.learning_rate + return initial_lr # Warmup phase elif lr_scheduler_args.lr_warmup_style is not None and current_step <= lr_scheduler_args.lr_warmup_steps: if lr_scheduler_args.lr_warmup_style == "linear": - lmbda = lr_scheduler_args.learning_rate * current_step / max(lr_scheduler_args.lr_warmup_steps, 1) + lmbda = initial_lr * current_step / max(lr_scheduler_args.lr_warmup_steps, 1) elif lr_scheduler_args.lr_warmup_style == "constant": lmbda = lr_scheduler_args.learning_rate else: @@ -117,21 +128,21 @@ def lr_lambda(current_step: int): # Optional constant phase at learning_rate elif current_step < lr_decay_starting_step: - lmbda = lr_scheduler_args.learning_rate + lmbda = initial_lr # Decay phase elif lr_scheduler_args.lr_decay_style is not None and current_step < lr_decay_starting_step + lr_decay_steps: if lr_scheduler_args.lr_decay_style == "cosine": lmbda = ( lr_scheduler_args.min_decay_lr - + (lr_scheduler_args.learning_rate - lr_scheduler_args.min_decay_lr) + + (initial_lr - lr_scheduler_args.min_decay_lr) * (1 + math.cos(math.pi * (current_step - lr_decay_starting_step) / lr_decay_steps)) / 2 ) elif lr_scheduler_args.lr_decay_style == "linear": lmbda = ( lr_scheduler_args.min_decay_lr - + (lr_scheduler_args.learning_rate - lr_scheduler_args.min_decay_lr) + + (initial_lr - lr_scheduler_args.min_decay_lr) * (lr_decay_steps - (current_step - lr_decay_starting_step)) / lr_decay_steps ) @@ -142,15 +153,146 @@ def lr_lambda(current_step: int): else: lmbda = lr_scheduler_args.min_decay_lr - lmbda /= lr_scheduler_args.learning_rate # Normalization for pytorch + lmbda /= initial_lr # Normalization for pytorch return lmbda - lr_scheduler = LambdaLR(optimizer.get_base_optimizer(), lr_lambda=lr_lambda) + def get_lr_lambda_for_param_group(lr: float): + return partial(lr_lambda, initial_lr=lr) + + # NOTE: get learning rate scheduler for each param group + lr_lambdas = [] + for param_group in optimizer.get_base_optimizer().param_groups: + lr_lambdas.append(get_lr_lambda_for_param_group(lr=param_group["lr"])) + + assert len(lr_lambdas) == len( + optimizer.get_base_optimizer().param_groups + ), "Custom learning rate functions dont match the number of param groups" + + log_rank( + f"[Optimizer Building] There are total {len(lr_lambdas)} custom learning rate function for parameter groups", + logger=logger, + level=logging.DEBUG, + ) + + lr_scheduler = LambdaLR(optimizer.get_base_optimizer(), lr_lambda=lr_lambdas) return lr_scheduler +def get_custom_weight_decay_for_named_parameters( + named_parameters: Iterable[Tuple[str, torch.Tensor]], + model: NanotronModel, + module_id_to_prefix: Dict[int, str], + weight_decay: float, +) -> List[Dict[str, Any]]: + """ + Apply weight decay to all parameters except the ones that are in the named_param_without_weight_decay list. 
+ """ + + named_param_groups_with_custom_weight_decay = [] + + exclude_named_params = model.get_named_params_without_weight_decay() + + for name, param in named_parameters: + if param.is_tied: + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + else: + pass + + if any(name.endswith(substring) for substring in exclude_named_params): + named_param_groups_with_custom_weight_decay.append({"named_params": [(name, param)], "weight_decay": 0.0}) + else: + named_param_groups_with_custom_weight_decay.append( + {"named_params": [(name, param)], "weight_decay": weight_decay} + ) + + log_rank( + f"[Optimizer Building] Creating {len(named_param_groups_with_custom_weight_decay)} param groups with custom weight decay", + logger=logger, + level=logging.DEBUG, + ) + return named_param_groups_with_custom_weight_decay + + +def get_custom_lr_for_named_parameters( + parametrization_method: ParametrizationMethod, + lr: float, + named_parameters: Iterable[Tuple[str, torch.Tensor]], + model: NanotronModel, +) -> List[Dict[str, Any]]: + """ + Get custom learning rates for parameters based on the parametrization method. + + NOTE: in some paramtrization methods, we use a global learning rate for all parameters, + in others we use a custom learning rate for each parameter (eg: spectral µTransfer). + """ + + assert parametrization_method in [ParametrizationMethod.SPECTRAL_MUP, ParametrizationMethod.STANDARD] + + lr_mapper_cls = ( + LearningRateForSpectralMup + if parametrization_method == ParametrizationMethod.SPECTRAL_MUP + else LearningRateForSP + ) + + log_rank( + f"[Optimizer Building] Using {lr_mapper_cls.__name__} as learning rate", + logger=logger, + level=logging.INFO, + rank=0, + ) + + # NOTE: since in the case of pipeline parallelism, each rank only has a subset of the model + # so we only get the parameters that are in the current rank + learning_rate_mapper = lr_mapper_cls(names_to_modules=model.named_modules_in_pp_rank, lr=lr) + + named_param_groups_with_custom_lr = [] + for ( + name, + param, + ) in named_parameters: + learning_rate = learning_rate_mapper.get_lr(name, param) + assert isinstance(learning_rate, float), f"Expected a float, got {learning_rate} for parameter {name}" + named_param_groups_with_custom_lr.append({"named_params": [(name, param)], "lr": learning_rate}) + + log_rank( + f"[Optimizer Building] Creating {len(named_param_groups_with_custom_lr)} param groups with custom learning rates", + logger=logger, + level=logging.DEBUG, + ) + + return named_param_groups_with_custom_lr + + +def merge_named_param_groups( + named_param_groups_with_lr: List[Dict[str, Any]], + named_param_groups_with_weight_decay: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + + assert len(named_param_groups_with_lr) == len( + named_param_groups_with_weight_decay + ), "Named param groups don't match in length" + + named_param_groups = [] + for group_with_lr, group_with_weight_decay in zip( + named_param_groups_with_lr, named_param_groups_with_weight_decay + ): + assert group_with_lr["named_params"] == group_with_weight_decay["named_params"] + named_param_groups.append( + { + "named_params": group_with_lr["named_params"], + "lr": group_with_lr["lr"], + "weight_decay": group_with_weight_decay["weight_decay"], + } + ) + + return named_param_groups + + def init_optimizer_and_grad_accumulator( - model: nn.Module, optimizer_args: OptimizerArgs, parallel_context: ParallelContext + parametrization_method: ParametrizationMethod, + model: nn.Module, + optimizer_args: 
OptimizerArgs, + parallel_context: ParallelContext, ) -> Tuple[BaseOptimizer, GradientAccumulator]: # Unwrap DDP unwrapped_model: NanotronModel = model.module if isinstance(model, DistributedDataParallel) else model @@ -161,18 +303,52 @@ def init_optimizer_and_grad_accumulator( named_parameters = list(unwrapped_model.get_named_params_with_correct_tied()) + named_param_groups_with_lr = get_custom_lr_for_named_parameters( + parametrization_method=parametrization_method, + named_parameters=named_parameters, + model=unwrapped_model, + lr=optimizer_args.learning_rate_scheduler.learning_rate, + ) + named_param_groups_with_weight_decay = get_custom_weight_decay_for_named_parameters( + named_parameters=named_parameters, + model=unwrapped_model, + module_id_to_prefix=module_id_to_prefix, + weight_decay=optimizer_args.weight_decay, + ) + + named_param_groups = merge_named_param_groups(named_param_groups_with_lr, named_param_groups_with_weight_decay) + # Basic optimizer builder def basic_optimizer_builder(named_param_groups): + optimizer = None + + if optimizer_args.optimizer_factory.name == "adamW": + + def optimizer(param_groups): + return torch.optim.AdamW( + param_groups, + lr=optimizer_args.learning_rate_scheduler.learning_rate, + weight_decay=optimizer_args.weight_decay, + eps=optimizer_args.optimizer_factory.adam_eps, + betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2), + fused=optimizer_args.optimizer_factory.torch_adam_is_fused, + ) + + elif optimizer_args.optimizer_factory.name == "sgd": + + def optimizer(param_groups): + return torch.optim.SGD( + param_groups, + lr=optimizer_args.learning_rate_scheduler.learning_rate, + weight_decay=optimizer_args.weight_decay, + ) + + else: + raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported") + return NamedOptimizer( named_params_or_groups=named_param_groups, - optimizer_builder=lambda param_groups: AdamW( # pylint: disable=E0601 - param_groups, - weight_decay=optimizer_args.weight_decay, - lr=optimizer_args.learning_rate_scheduler.learning_rate, - eps=optimizer_args.adam_eps, - betas=(optimizer_args.adam_beta1, optimizer_args.adam_beta2), - fused=optimizer_args.torch_adam_is_fused, - ), + optimizer_builder=optimizer, ) optimizer_builder = basic_optimizer_builder @@ -202,7 +378,7 @@ def grad_optimizer_builder(named_param_groups): if optimizer_args.zero_stage > 0: # Build optimizer optimizer = ZeroDistributedOptimizer( - named_params_or_groups=named_parameters, + named_params_or_groups=named_param_groups, # TODO @thomasw21: We need a better API for gradient accumulation/zero etc ... 
optimizer_builder=optimizer_builder, dp_pg=parallel_context.dp_pg, @@ -220,7 +396,7 @@ def grad_optimizer_builder(named_param_groups): assert param.data_ptr() == optim_model_param.data_ptr() else: # Build optimizer - optimizer = optimizer_builder(named_parameters) + optimizer = optimizer_builder(named_param_groups) if grad_accumulator is not None and optimizer_args.zero_stage > 0: # There's a way to only require to reduce_scatter the gradients instead of all_reducing @@ -259,7 +435,8 @@ def grad_optimizer_builder(named_param_groups): def test_equal_dict(first: Dict, second: Dict, sub_paths: Optional[List[str]] = None) -> None: - """Raise if doesn't match""" + """Raise if doesn't match.""" + if sub_paths is None: sub_paths = [] @@ -493,3 +670,39 @@ def log_throughput( if dist.get_rank(parallel_context.world_pg) == 0: write_to_csv(config.general.benchmark_csv_path, table_log, model_tflops, slurm_job_id) + + +def compute_remain_train_steps_of_a_data_stage_from_ckp( + stage: DatasetStageArgs, config: Config, metadata: TrainingMetadata +) -> int: + def is_last_stage(): + sorted_stages = sorted(config.data_stages, key=lambda x: x.start_training_step) + return sorted_stages[-1].start_training_step == stage.start_training_step + + def is_resume_from_training(): + return metadata.last_train_step > 0 + + if is_last_stage() is True: + total_train_steps = config.tokens.train_steps + else: + next_stage = next((s for s in config.data_stages if s.start_training_step > stage.start_training_step), None) + total_train_steps = next_stage.start_training_step + + if metadata.last_train_step > stage.start_training_step: + # NOTE: if the last_train_step is larger than the start_training_step of the current stage, + # it means that the training has already passed this stage + # so there is no remaining steps + return 0 + else: + last_train_steps = metadata.last_train_step if is_resume_from_training() else stage.start_training_step + return total_train_steps - last_train_steps + + +def get_consumed_train_samples_of_a_data_stage_from_ckp( + stage: DatasetStageArgs, metadata: TrainingMetadata +) -> Optional[int]: + start_training_step = stage.start_training_step + return next( + (s.consumed_train_samples for s in metadata.data_stages if s.start_training_step == start_training_step), + None, + ) diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index ac54e9fb..14ac6908 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -1,6 +1,6 @@ from abc import ABCMeta, abstractmethod from contextlib import contextmanager -from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Optional, Tuple import numpy as np import torch @@ -65,13 +65,16 @@ def tie_custom_params(self) -> None: """Tie custom parameters. For example for MQA marks kv heads as tied.""" pass - @staticmethod - def get_embeddings_lm_head_tied_names() -> list[str]: + def get_embeddings_lm_head_tied_names(self) -> list[str]: """Returns the names of the embeddings and lm_head weights that are tied together. Returns empty list if not tied. 
Example for GPT2 model: ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] """ return [] + + def get_named_params_without_weight_decay(self) -> List[str]: + """Return a list of named parameters that should not have weight decay applied to them.""" + return [] def before_tbi_sanity_checks(self) -> None: pass @@ -99,6 +102,39 @@ def log_modules(self, level: int = logging.DEBUG, group: Optional[ProcessGroup] rank=rank, ) + @property + def named_modules_in_pp_rank(self) -> Dict[str, nn.Module]: + """Return the named modules that only belong to the current pp rank. + + An example output: + { + 'module_name': module, + ... + } + + NOTE: keys do not include module_name.weight or module_name.bias, only module_name + """ + + def get_leaf_modules(module: nn.Module) -> List[Tuple[str, nn.Module]]: + """ + Return all the leaf modules (modules without any child modules) in a PyTorch module. + """ + leaf_modules = [] + for n, m in module.named_modules(): + if not list(m.children()): + leaf_modules.append((n, m)) + return leaf_modules + + modules = get_leaf_modules(self) + named_modules_in_current_pp_rank = {} + for name, module in modules: + if isinstance(module, PipelineBlock): + # NOTE: these are the modules that don't belong to the current pp rank + continue + named_modules_in_current_pp_rank[name] = module + + return named_modules_in_current_pp_rank + class DTypeInvariantTensor(torch.Tensor): """DTypeInvariantTensor is a subclass of torch.Tensor that disallows modification of its dtype. Note that the data @@ -171,7 +207,7 @@ def build_model( pp_size = len(target_pp_ranks) # Set rank for each pipeline block - log_rank("Setting PP block ranks..", logger=logger, level=logging.INFO, rank=0, group=parallel_context.world_pg) + log_rank("Setting PP block ranks...", logger=logger, level=logging.INFO, rank=0, group=parallel_context.world_pg) pipeline_blocks = [module for name, module in model.named_modules() if isinstance(module, PipelineBlock)] # "cuda" is already defaulted for each process to it's own cuda device with init_on_device_and_dtype(device=device, dtype=dtype): diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index b930e0eb..32aab9cd 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -12,22 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch LLaMa model. 
-""" -from typing import Dict, Optional, Union -import math +"""PyTorch LLaMa model.""" + +from typing import Dict, Optional, Union, List + import torch -from flash_attn import bert_padding -from flash_attn.flash_attn_interface import ( - flash_attn_varlen_func, - flash_attn_with_kvcache, -) -from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding from torch import nn from nanotron import distributed as dist from nanotron import logging -from nanotron.config import LlamaConfig, ParallelismArgs +from nanotron.config import Config, LlamaConfig, ParallelismArgs +from nanotron.config.models_config import RandomInit, SpectralMupInit from nanotron.generation.generate_store import AttachableStore from nanotron.logging import log_rank from nanotron.models import NanotronModel @@ -35,10 +30,7 @@ from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter -from nanotron.parallel.pipeline_parallel.block import ( - PipelineBlock, - TensorPointer, -) +from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( @@ -48,6 +40,7 @@ TensorParallelRowLinear, ) from nanotron.random import RandomStates +from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator from nanotron.utils import checkpoint_method logger = logging.get_logger(__name__) @@ -162,7 +155,6 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, ) - self.down_proj = TensorParallelRowLinear( config.intermediate_size, config.hidden_size, @@ -189,6 +181,7 @@ def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArg ), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}." self.d_qk = config.hidden_size // config.num_attention_heads self.d_v = config.hidden_size // config.num_attention_heads + self.is_using_mup = config.is_using_mup self.checkpoint_attention = False # Because flash_attn already does checkpointing @@ -201,6 +194,8 @@ def forward( q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size) kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size) ): + from flash_attn.flash_attn_interface import flash_attn_varlen_func + # TODO @thomasw21: Compute once, instead of computing for each layers. cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) @@ -210,6 +205,10 @@ def forward( # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. 
causal = False if q_sequence_mask.shape[1] == 1 else True + + # NOTE: this scale is for µTransfer, + # in SP, we use sqrt(1/d_h) + softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None attn_output = flash_attn_varlen_func( q=query_states, k=key_states, @@ -219,7 +218,7 @@ def forward( max_seqlen_q=q_sequence_mask.shape[1], max_seqlen_k=kv_sequence_mask.shape[1], dropout_p=0.0, - softmax_scale=None, # This already defaults to the scale I'm interested in + softmax_scale=softmax_scale, causal=causal, return_attn_probs=False, ) @@ -263,6 +262,8 @@ def __init__( tp_pg: dist.ProcessGroup, layer_idx: int, ): + from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding + super().__init__() # Tensor parallel considerations: We split tensors along head dimension assert ( @@ -291,6 +292,7 @@ def __init__( self.d_qk = config.hidden_size // config.num_attention_heads self.d_v = config.hidden_size // config.num_attention_heads self.d_model = config.hidden_size + self.is_using_mup = config.is_using_mup # TODO @thomasw21: refactor so that we store that default in a single place. tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE @@ -347,6 +349,12 @@ def forward( hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] ): + from flash_attn import bert_padding + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) + qkv_states = self.qkv_proj( hidden_states ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] @@ -433,6 +441,9 @@ def forward( ) (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask) + # NOTE: this scale is for µTransfer, + # in SP, we use sqrt(1/d_h) + softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None output_unpad = flash_attn_varlen_func( q=query_unpad, # (total_q, n_local_q_heads, d_qk) k=key_unpad, # (total_kv, n_local_kv_heads, d_qk) @@ -442,7 +453,7 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=0.0, - softmax_scale=None, + softmax_scale=softmax_scale, causal=True, # True in prefill phase, False in subsequent phases return_attn_probs=False, ) # (total_unpadded, n_local_q_heads, d_v) @@ -516,6 +527,9 @@ def forward( batch_size, kv_length, self.n_local_kv_heads, self.d_v ) # [batch_size, kv_length, self.n_heads, d_v] + # NOTE: this scale is for µTransfer, + # in SP, we use sqrt(1/d_h) + softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None attention_output = flash_attn_with_kvcache( query_states, k_cache, @@ -524,9 +538,9 @@ def forward( value_states, rotary_cos=None, rotary_sin=None, - # TODO @nouamane: seems like this doesnt help to indicate padding in (for first iteration it's just 0) + # TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0) cache_seqlens=position_offsets.contiguous(), - softmax_scale=None, + softmax_scale=softmax_scale, causal=True, rotary_interleaved=False, # GPT-NeoX style ) @@ -829,6 +843,7 @@ def forward( ) -> Dict[str, torch.Tensor]: # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. 
# https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 + loss = sharded_cross_entropy( sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float ).transpose(0, 1) @@ -881,13 +896,30 @@ def forward( label_mask=label_mask, )["loss"] return {"loss": loss} - + @torch.no_grad() - def init_model_randomly(self, config): + def init_model_randomly(self, config: Config): """Initialize model parameters randomly. Note: Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` """ + init_method = config.model.init_method + if isinstance(init_method, RandomInit): + parametrizator_cls = StandardParametrizator + elif isinstance(init_method, SpectralMupInit): + parametrizator_cls = SpectralMupParametrizator + else: + raise ValueError(f"Unknown init method {init_method}") + + parametrizator = parametrizator_cls(config=config.model) + + log_rank( + f"Parametrizing model parameters using {parametrizator.__class__.__name__}", + logger=logger, + level=logging.INFO, + rank=0, + ) + model = self initialized_parameters = set() # Handle tensor parallelism @@ -895,15 +927,11 @@ def init_model_randomly(self, config): # Fix the root_model module_id_to_prefix[id(model)] = "" - std = config.model.init_method.std - sigma = config.model.init_method.std - num_layers = config.model.model_config.num_hidden_layers - for param_name, param in model.named_parameters(): assert isinstance(param, NanotronParameter) - - module_name, param_name = param_name.rsplit('.', 1) - + + module_name, param_name = param_name.rsplit(".", 1) + if param.is_tied: tied_info = param.get_tied_info() full_param_name = tied_info.get_full_name_from_module_id_to_prefix( @@ -917,37 +945,11 @@ def init_model_randomly(self, config): continue module = model.get_submodule(module_name) - - if isinstance(module, TensorParallelColumnLinear): - if "weight" == param_name: - torch.nn.init.normal_(module.weight, mean=0.0, std=std) - elif "bias" == param_name: - module.bias.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - elif isinstance(module, TensorParallelRowLinear): - if "weight" == param_name: - torch.nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - elif isinstance(module, TritonRMSNorm): - if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 - module.weight.fill_(1) - elif "bias" == param_name: - module.bias.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - elif isinstance(module, TensorParallelEmbedding): - nn.init.normal_(module.weight, mean=0.0, std=std) - else: - raise Exception(f"Parameter {full_param_name} was not intialized") + parametrizator.parametrize(param_name, module) assert full_param_name not in initialized_parameters initialized_parameters.add(full_param_name) - + assert initialized_parameters == { param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) if param.is_tied diff --git a/src/nanotron/models/starcoder2.py b/src/nanotron/models/starcoder2.py index 81b5bca6..7100351d 100644 --- a/src/nanotron/models/starcoder2.py +++ b/src/nanotron/models/starcoder2.py @@ -24,15 +24,9 @@ from typing import Dict, List, Optional, Tuple, Union import torch -from flash_attn import bert_padding -from flash_attn.flash_attn_interface import ( - flash_attn_varlen_func, - flash_attn_with_kvcache, -) from torch 
import nn -from torch.nn import LayerNorm +from torch.nn import LayerNorm, init from torch.nn import functional as F -from torch.nn import init from nanotron import distributed as dist from nanotron.config import ParallelismArgs, Starcoder2Config @@ -63,8 +57,6 @@ from nanotron.random import RandomStates, branch_random_state from nanotron.utils import checkpoint_method -_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_varlen_func).parameters) - def pad_to_right(tensor, mask, new_tensor=None): """Transform a left-padded tensor into a right-padded tensor. (Useful for prefilling key/value states) @@ -228,6 +220,10 @@ class CoreAttention(nn.Module): def __init__(self, config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): super().__init__() + from flash_attn.flash_attn_interface import flash_attn_varlen_func + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_varlen_func).parameters) + assert ( config.hidden_size % config.num_attention_heads == 0 ), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}." @@ -258,6 +254,8 @@ def forward( q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size) kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size) ): + from flash_attn.flash_attn_interface import flash_attn_varlen_func + # TODO @thomasw21: Compute once, instead of computing for each layers. cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) @@ -682,6 +680,12 @@ def forward( hidden_states, # [seq_length, batch_size, hidden_dim] sequence_mask, # [batch_size, seq_length] ): + from flash_attn import bert_padding + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) + batch_size = hidden_states.shape[1] def unshape(states): @@ -833,7 +837,7 @@ def shape( value_states, rotary_cos=None, rotary_sin=None, - # TODO @nouamane: seems like this doesnt help to indicate padding in (for first iteration it's just 0) + # TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0) cache_seqlens=position_offsets.contiguous(), softmax_scale=None, causal=True, @@ -956,6 +960,12 @@ def forward( hidden_states, # (seq_length, batch_size, hidden_size) sequence_mask, # (batch_size, seq_length) ): + from flash_attn import bert_padding + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) + fused_qkv = self.query_key_value( hidden_states ) # [seq_length, batch_size, n_local_q_heads * head_dim + 2 * n_local_kv_heads * head_dim] @@ -1072,7 +1082,7 @@ def forward( value_states, rotary_cos=None, rotary_sin=None, - # TODO @nouamane: seems like this doesnt help to indicate padding in (for first iteration it's just 0) + # TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0) cache_seqlens=position_offsets.contiguous(), softmax_scale=None, causal=True, @@ -1533,7 +1543,7 @@ def init_model_randomly(self, config): elif isinstance(module, TensorParallelEmbedding): nn.init.normal_(module.weight, mean=0.0, std=std) else: - raise Exception(f"Parameter {full_param_name} was not intialized") + raise Exception(f"Parameter {full_param_name} 
was not initialized") assert full_param_name not in initialized_parameters initialized_parameters.add(full_param_name) @@ -1545,8 +1555,7 @@ def init_model_randomly(self, config): for name, param in model.named_parameters() }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" - @staticmethod - def get_embeddings_lm_head_tied_names() -> List[str]: + def get_embeddings_lm_head_tied_names(self) -> List[str]: return [ "model.token_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight", diff --git a/src/nanotron/nn/layer_norm.py b/src/nanotron/nn/layer_norm.py index 04ffa377..688eaa78 100644 --- a/src/nanotron/nn/layer_norm.py +++ b/src/nanotron/nn/layer_norm.py @@ -1,5 +1,4 @@ import torch -from flash_attn.ops.triton.layer_norm import layer_norm_fn from torch import nn @@ -7,6 +6,8 @@ class TritonLayerNorm(nn.LayerNorm): def forward( self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False ): + from flash_attn.ops.triton.layer_norm import layer_norm_fn + return layer_norm_fn( input, self.weight, @@ -36,6 +37,8 @@ def reset_parameters(self): def forward( self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False ): + from flash_attn.ops.triton.layer_norm import layer_norm_fn + return layer_norm_fn( input, self.weight, diff --git a/src/nanotron/optim/named_optimizer.py b/src/nanotron/optim/named_optimizer.py index 23614b05..74214357 100644 --- a/src/nanotron/optim/named_optimizer.py +++ b/src/nanotron/optim/named_optimizer.py @@ -2,11 +2,14 @@ import torch +from nanotron import logging from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer +logger = logging.get_logger(__name__) + class NamedOptimizer(InheritFromOtherOptimizer): - """Mimicks somewhat the torch optimizer API""" + """Mimics somewhat the torch optimizer API""" def __init__( self, diff --git a/src/nanotron/parallel/pipeline_parallel/block.py b/src/nanotron/parallel/pipeline_parallel/block.py index 273f1432..150172f5 100644 --- a/src/nanotron/parallel/pipeline_parallel/block.py +++ b/src/nanotron/parallel/pipeline_parallel/block.py @@ -1,6 +1,8 @@ from typing import Any, Callable, Dict, Optional, Set, Tuple, Union import torch +from torch import nn + from nanotron import distributed as dist from nanotron.parallel.pipeline_parallel.functional import ( recv_from_pipeline_state_buffer, @@ -9,7 +11,6 @@ from nanotron.parallel.pipeline_parallel.p2p import P2P, BatchTensorSendRecvState from nanotron.parallel.pipeline_parallel.state import PipelineBatchState, PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer -from torch import nn class PipelineBlock(nn.Module): @@ -19,7 +20,7 @@ class PipelineBlock(nn.Module): - PipelineBlocks have to wrap a method/function/module that outputs a Dict[str, torch.Tensor] Some considerations: - - In the litterature, authors often refer to pipeline stages as a granularity block. Our notion is more granular. A pipeline stage is list of contiguous (in the forward sense) of pipeline blocks. + - In the literature, authors often refer to pipeline stages as a granularity block. Our notion is more granular. A pipeline stage is list of contiguous (in the forward sense) of pipeline blocks. 
All PipelineBlock definition exist in each rank, they are just instantiated/built on a single rank per pipeline parallel process group. """ diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py new file mode 100644 index 00000000..e6241651 --- /dev/null +++ b/src/nanotron/scaling/parametrization.py @@ -0,0 +1,202 @@ +import math +from abc import abstractmethod +from enum import Enum, auto +from typing import Dict + +from nanotron.config import ModelArgs +from nanotron.nn.layer_norm import TritonRMSNorm +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) +from torch import nn +from torch.nn import init + + +class ParametrizationMethod(Enum): + STANDARD = auto() + SPECTRAL_MUP = auto() + + +class Parametrizator: + def __init__(self, config: ModelArgs): + self.config = config + + def parametrize(self, param_name: str, module: nn.Module): + if not isinstance(module, tuple(self.MODULE_TO_PARAMETRIZE.keys())): + raise Exception(f"Parameter {param_name} was not initialized") + + return self.MODULE_TO_PARAMETRIZE[type(module)](param_name, module) + + +class StandardParametrizator(Parametrizator): + def __init__(self, config: ModelArgs): + super().__init__(config) + self.MODULE_TO_PARAMETRIZE = { + TensorParallelColumnLinear: self._parametrize_column_linear, + TensorParallelRowLinear: self._parametrize_row_linear, + TritonRMSNorm: self._parametrize_layer_norm, + TensorParallelEmbedding: self._parametrize_embedding, + } + + self.std = config.init_method.std + self.num_layers = config.model_config.num_hidden_layers + + def _parametrize_column_linear(self, param_name: str, module: nn.Module): + assert param_name in ["weight", "bias"] + + if "weight" == param_name: + init.normal_(module.weight, mean=0.0, std=self.std) + elif "bias" == param_name: + module.bias.zero_() + + def _parametrize_row_linear(self, param_name: str, module: nn.Module): + assert param_name in ["weight", "bias"] + + if "weight" == param_name: + std = self.std / math.sqrt(2 * self.num_layers) + init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + + def _parametrize_layer_norm(self, param_name: str, module: nn.Module): + assert param_name in ["weight", "bias"] + + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + module.weight.fill_(1) + elif "bias" == param_name: + module.bias.zero_() + + def _parametrize_embedding(self, param_name: str, module: nn.Module): + assert param_name in ["weight"] + + if "weight" == param_name: + init.normal_(module.weight, mean=0.0, std=self.std) + + +class SpectralMupParametrizator(Parametrizator): + """ + A Spectral Condition for Feature Learning by Greg Yang, et al. + https://arxiv.org/abs/2310.17813 + """ + + def __init__(self, config: ModelArgs): + super().__init__(config) + self.MODULE_TO_PARAMETRIZE = { + TensorParallelColumnLinear: self._parametrize_mup_weight, + TensorParallelRowLinear: self._parametrize_mup_weight, + TritonRMSNorm: self._parametrize_layer_norm, + TensorParallelEmbedding: self._parametrize_embedding, + } + self.std = 1.0 + + @staticmethod + def _compute_spectral_std(std: float, fan_in: int, fan_out: int): + """ + Parametrization 1 (Spectral parametrization) + Page 8, A Spectral Condition for Feature Learning by Greg Yang, et al. 
+ + σₗ = Θ(1/√nₗ₋₁ min{1, √(nₗ/nₗ₋₁)}) + """ + return (std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in)) + + def _parametrize_mup_weight(self, param_name: str, module: nn.Module): + assert param_name in ["weight", "bias"] + + data = module.weight if param_name == "weight" else module.bias + fan_in, fan_out = init._calculate_fan_in_and_fan_out(data) + world_size = module.world_size + + if isinstance(module, TensorParallelColumnLinear): + fan_out = fan_out * world_size + elif isinstance(module, TensorParallelRowLinear): + fan_in = fan_in * world_size + else: + raise ValueError(f"Unknown module {module}") + + std = SpectralMupParametrizator._compute_spectral_std(std=self.std, fan_in=fan_in, fan_out=fan_out) + init.normal_(data, mean=0.0, std=std) + + def _parametrize_layer_norm(self, param_name: str, module: nn.Module): + assert param_name in ["weight", "bias"] + + # NOTE: you're free to change the initialization of layer norm + # as it's not a part of µTransfer + if "weight" == param_name: + module.weight.fill_(1) + elif "bias" == param_name: + module.bias.zero_() + + def _parametrize_embedding(self, param_name: str, module: nn.Module): + assert param_name in ["weight"] + + # NOTE: you're free to change the initialization of input embedding/lm head + if "weight" == param_name: + init.normal_(module.weight, mean=0.0, std=self.std) + + +class LearningRateForParametrizator: + def __init__(self, lr: float, names_to_modules: Dict[str, nn.Module]): + self.lr = lr + self.names_to_modules = names_to_modules + + @abstractmethod + def get_lr(self, param_name: str, module: nn.Module) -> float: + raise NotImplementedError + + +class LearningRateForSP(LearningRateForParametrizator): + """All parameters get the same learning rate.""" + + def get_lr(self, param_name: str, param: nn.Module) -> float: + return self.lr + + +class LearningRateForSpectralMup(LearningRateForParametrizator): + """ + A Spectral Condition for Feature Learning by Greg Yang, et al. + + NOTE: each parameter gets a custom learning rate based on its fan-in and fan-out. + """ + + def __init__(self, lr: float, names_to_modules: Dict[str, nn.Module]): + super().__init__(lr, names_to_modules) + self.MODULE_TO_PARAMETRIZE = { + TensorParallelColumnLinear: self._get_mup_lr, + TensorParallelRowLinear: self._get_mup_lr, + TritonRMSNorm: self._get_global_lr, + TensorParallelEmbedding: self._get_global_lr, + } + + def _get_mup_lr(self, param: nn.Parameter, module: nn.Module): + """ + Parametrization 1 (Spectral parametrization) + Page 8, A Spectral Condition for Feature Learning by Greg Yang, et al. 
+ + ηₗ = Θ(nₗ/nₗ₋₁) + """ + fan_in, fan_out = init._calculate_fan_in_and_fan_out(param) + world_size = module.world_size + + if isinstance(module, TensorParallelColumnLinear): + fan_out = fan_out * world_size + elif isinstance(module, TensorParallelRowLinear): + fan_in = fan_in * world_size + else: + raise ValueError(f"Unknown module {module}") + + return self.lr * (fan_out / fan_in) + + def _get_global_lr(self, param: nn.Parameter, module: nn.Module) -> float: + return self.lr + + def get_lr(self, param_name: str, param: nn.Parameter) -> float: + """Return the learning rate for the given parameter.""" + # NOTE: param_name should be like 'model.token_position_embeddings.pp_block.token_embedding.weight' + # since names_to_modules map module_name to module + # so we remove the .weight and .bias from param_name to get the module_name + module_name = param_name.rsplit(".", 1)[0] + module = self.names_to_modules[module_name] + return self.MODULE_TO_PARAMETRIZE[type(module)](param, module) diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index a2a3d4aa..286008ac 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -1,14 +1,16 @@ from pathlib import Path -from typing import Optional +from typing import Optional, cast import torch from torch import nn from torch.nn.parallel import DistributedDataParallel +from torch.optim.lr_scheduler import LambdaLR from nanotron import distributed as dist from nanotron import logging from nanotron import optim as optim from nanotron.config import Config +from nanotron.constants import MODEL_CONFIG_FILE_NAME from nanotron.distributed import get_global_rank from nanotron.logging import log_rank from nanotron.parallel import ParallelContext @@ -17,7 +19,7 @@ assert_tensor_synced_across_pg, check_optim_state_in_sync, ) -from nanotron.serialize.metadata import CheckpointMetadata, load_meta, save_meta +from nanotron.serialize.metadata import CheckpointMetadata, TrainingMetadata, load_meta, save_meta from nanotron.serialize.optimizer import ( load_lr_scheduler, load_optimizer, @@ -51,16 +53,15 @@ def save( optimizer: optim.BaseOptimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler, parallel_context: ParallelContext, + training_metadata: TrainingMetadata, root_folder: Path, should_save_config: bool = True, should_save_model: bool = True, should_save_optimizer: bool = True, should_save_lr_scheduler: bool = True, - checkpoint_metadata: dict = None, sanity_checks: bool = True, ) -> None: - if checkpoint_metadata is None: - checkpoint_metadata = {} + assert isinstance(training_metadata, TrainingMetadata) try: if should_save_config: @@ -98,6 +99,11 @@ def save( raise e try: if should_save_lr_scheduler: + lr_scheduler = cast(LambdaLR, lr_scheduler) + assert len(lr_scheduler.lr_lambdas) == len( + optimizer.param_groups + ), "The number of lambdas functions in the scheduler should be equal to the number of parameter groups in the optimizer." 
+ save_lr_scheduler( lr_scheduler=lr_scheduler, parallel_context=parallel_context, @@ -112,7 +118,7 @@ def save( ) raise e - save_meta(root_folder=root_folder, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata) + save_meta(root_folder=root_folder, parallel_context=parallel_context, training_metadata=training_metadata) # TODO @thomas21: sanity check, not sure whether that needs to happen at testing or now (depends how much it costs) ### @@ -194,7 +200,6 @@ def save( rtol=0, msg=lambda msg: f"tensor at {current_state_dict['names'][index]} doesn't match with our reference. Optimizer key: {name}\nCur: {tensor}\nRef: {reference_tensor}\n{msg}", ) - ### dist.barrier(parallel_context.world_pg) @@ -256,7 +261,7 @@ def parse_ckpt_path(config: Config) -> Optional[Path]: load_from_candidate = int(fi.read()) checkpoint_path = config.checkpoints.resume_checkpoint_path / str(load_from_candidate) - elif (config.checkpoints.resume_checkpoint_path / "model_config.json").exists(): + elif (config.checkpoints.resume_checkpoint_path / MODEL_CONFIG_FILE_NAME).exists(): # we assume that the checkpoint path is a path to a checkpoint checkpoint_path = config.checkpoints.resume_checkpoint_path diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py index 0953a522..0d8708f9 100644 --- a/src/nanotron/serialize/metadata.py +++ b/src/nanotron/serialize/metadata.py @@ -1,7 +1,7 @@ import dataclasses import json from pathlib import Path -from typing import Any, Callable, ClassVar, Dict, List, Tuple, Type, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union import dacite import torch @@ -9,18 +9,62 @@ from packaging.version import Version from nanotron import distributed as dist -from nanotron.constants import CHECKPOINT_VERSION +from nanotron.constants import CHECKPOINT_FILE_NAME, CHECKPOINT_VERSION from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import SlicesPair +@dataclasses.dataclass +class DataStageMetadata: + """ + consumed_train_samples: The number of samples consumed by the model in the this stage (each stage starts from zero). + last_train_step: The last training step across all stages. + + # NOTE: we should allow people to change the name of the data stages in the config file. + # but not the start_training_step, because it could + """ + + name: str + start_training_step: int + consumed_train_samples: int + + +@dataclasses.dataclass +class TrainingMetadata: + """ + consumed_train_samples: The number of samples consumed globally, across all stages. + last_train_step: The last training step across all stages. + last_stage_idx: The index of the last stage that was trained. + data_stages: The metadata for each stage. + """ + + consumed_train_samples: int + last_train_step: int + + # TODO(xrsrke): make this not optional, once we entirely remove + # the old checkpoint version + last_stage_idx: Optional[int] = None + data_stages: Optional[List[DataStageMetadata]] = None + + def __post_init__(self): + # NOTE: this is a sanity check after loading a trained checkpoint + total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages) + assert ( + self.consumed_train_samples == total_consumed_samples_across_stages + ), "Mismatch between the total consumed samples and the sum of consumed samples across stages! Something went wrong in the training." 
+ + # TODO(xrsrke): remove this once we entirely remove non-data-stage training + if self.last_stage_idx is not None: + assert self.data_stages is not None, "data_stages should not be None if last_stage_idx is not None" + + @dataclasses.dataclass class CheckpointMetadata: version: Version tp: int dp: int - # Anything users want to store - metas: Dict + metas: TrainingMetadata + custom_metas: Optional[Dict[str, Any]] = None @dataclasses.dataclass @@ -81,7 +125,9 @@ def to_list(list_: Union[List, Tuple], type_hooks: Dict[Type, Callable[[Any], An return list_.__class__((process_type(elt, type_hooks=type_hooks) for elt in list_)) -def save_meta(parallel_context: ParallelContext, root_folder: Path, checkpoint_metadata: dict): +def save_meta(parallel_context: ParallelContext, root_folder: Path, training_metadata: TrainingMetadata): + assert isinstance(training_metadata, TrainingMetadata) + if dist.get_rank(parallel_context.world_pg) != 0: return @@ -90,18 +136,18 @@ def save_meta(parallel_context: ParallelContext, root_folder: Path, checkpoint_m version=CHECKPOINT_VERSION, tp=parallel_context.tp_pg.size(), dp=parallel_context.dp_pg.size(), - metas=checkpoint_metadata, + metas=training_metadata, ) # There are some types that require manual casting in order to work correctly. processed_metadata = process_type(dataclasses.asdict(checkpoint_metadata), type_hooks={Version: lambda x: str(x)}) - with open(root_folder / "checkpoint_metadata.json", mode="w") as fo: + with open(root_folder / CHECKPOINT_FILE_NAME, mode="w") as fo: json.dump(processed_metadata, fo, indent=2, sort_keys=True) def load_meta(parallel_context: ParallelContext, root_folder: Path) -> CheckpointMetadata: - with open(root_folder / "checkpoint_metadata.json", mode="r") as fi: + with open(root_folder / CHECKPOINT_FILE_NAME, mode="r") as fi: checkpoint_metadata = json.load(fi) checkpoint_metadata = from_dict( data_class=CheckpointMetadata, diff --git a/src/nanotron/serialize/weights.py b/src/nanotron/serialize/weights.py index c857154f..9a291d38 100644 --- a/src/nanotron/serialize/weights.py +++ b/src/nanotron/serialize/weights.py @@ -4,7 +4,6 @@ import dacite import torch from packaging.version import Version -from safetensors import SafetensorError from safetensors.torch import safe_open, save_file from torch import nn from tqdm import tqdm @@ -97,11 +96,6 @@ def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folde path.parent.mkdir(exist_ok=True, parents=True) try: tensors = {"data": param_or_buffer} - - # Mamba has some parameters that should not be weight decayed - if hasattr(model.get_parameter(name), "_no_weight_decay"): - tensors.update({"_no_weight_decay": torch.tensor(model.get_parameter(name)._no_weight_decay)}) - save_file(tensors=tensors, filename=path, metadata=metadata) except Exception as e: log_rank( @@ -268,12 +262,6 @@ def load_weights( # TODO @thomasw21: Choose only a slice if we switch the TP topology param_or_buffer[:] = fi.get_tensor("data") - # Only Mamba params has this attribute - try: - param._no_weight_decay = fi.get_tensor("_no_weight_decay") - except SafetensorError: - pass - elif not path.parent.exists(): raise ValueError( f"Checkpoint is empty or checkpoint structure is not matching the model architecture." 
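As a quick illustration of the new checkpoint metadata added in `src/nanotron/serialize/metadata.py` above, here is a minimal sketch of how a run with two data stages would be recorded. The stage names and counts are hypothetical; the fields and the `__post_init__` consistency check (the global `consumed_train_samples` must equal the per-stage sum) are the ones introduced by this diff.

```python
from nanotron.serialize.metadata import DataStageMetadata, TrainingMetadata

# Hypothetical two-stage run: the second stage starts at training step 2_001.
data_stages = [
    DataStageMetadata(name="stable_phase", start_training_step=1, consumed_train_samples=8_000),
    DataStageMetadata(name="annealing_phase", start_training_step=2_001, consumed_train_samples=2_000),
]

# The global counter must equal 8_000 + 2_000, otherwise __post_init__ raises.
metadata = TrainingMetadata(
    consumed_train_samples=10_000,
    last_train_step=2_500,
    last_stage_idx=1,  # index of the stage currently being trained
    data_stages=data_stages,
)
```

This `TrainingMetadata` object is what `save(..., training_metadata=...)` and `save_meta(...)` now expect in place of the previous free-form `checkpoint_metadata` dict.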
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 4abf3722..b23b99b3 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -31,11 +31,15 @@ ExistingCheckpointInit, ParallelismArgs, RandomInit, + SpectralMupInit, get_config_from_file, ) +from nanotron.constants import MODEL_CONFIG_FILE_NAME from nanotron.dataloader import sanity_check_dataloader from nanotron.helpers import ( _vocab_size_with_padding, + compute_remain_train_steps_of_a_data_stage_from_ckp, + get_consumed_train_samples_of_a_data_stage_from_ckp, get_profiler, init_optimizer_and_grad_accumulator, init_random_states, @@ -78,6 +82,7 @@ before_optim_step_sanity_checks, before_tbi_sanity_checks, ) +from nanotron.scaling.parametrization import ParametrizationMethod from nanotron.serialize import ( load_lr_scheduler, load_meta, @@ -86,6 +91,7 @@ save, save_random_states, ) +from nanotron.serialize.metadata import DataStageMetadata, TrainingMetadata from nanotron.serialize.optimizer import load_optimizer logger = logging.get_logger(__name__) @@ -168,9 +174,19 @@ def __init__( self.model.module if isinstance(self.model, DistributedDataParallel) else self.model ) + # TODO: find a better way to handle this + parametrization_method = ( + ParametrizationMethod.SPECTRAL_MUP + if hasattr(self.config.model.init_method, "use_mup") and self.config.model.init_method.use_mup + else ParametrizationMethod.STANDARD + ) + # Init optimizer self.optimizer, self.grad_accumulator = init_optimizer_and_grad_accumulator( - model=self.model, optimizer_args=self.config.optimizer, parallel_context=self.parallel_context + parametrization_method=parametrization_method, + model=self.model, + optimizer_args=self.config.optimizer, + parallel_context=self.parallel_context, ) if self.init_checkpoint_path is not None: load_optimizer( @@ -198,15 +214,23 @@ def __init__( checkpoint_metadata = load_meta( parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path ) + assert isinstance(checkpoint_metadata.metas, TrainingMetadata) log_rank(str(checkpoint_metadata), logger=logger, level=logging.INFO, rank=0) - self.start_iteration_step = checkpoint_metadata.metas["last_train_step"] - self.consumed_train_samples = checkpoint_metadata.metas["consumed_train_samples"] + self.metadata: TrainingMetadata = checkpoint_metadata.metas + # NOTE: we should not change data stages assert ( - self.config.tokens.train_steps > self.start_iteration_step - ), f"Loaded checkpoint has already trained {self.start_iteration_step} batches, you need to specify a higher `config.tokens.train_steps`" + self.config.tokens.train_steps > self.metadata.last_train_step + ), f"Loaded checkpoint has already trained {self.metadata.last_train_step} batches, you need to specify a higher `config.tokens.train_steps`" else: - self.start_iteration_step = 0 - self.consumed_train_samples = 0 + data_stages = [ + DataStageMetadata( + name=stage.name, start_training_step=stage.start_training_step, consumed_train_samples=0 + ) + for stage in self.config.data_stages + ] + self.metadata: TrainingMetadata = TrainingMetadata( + consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages + ) # Setup tensorboard write and log writers on output rank self.logger_ranks = self.parallel_context.get_global_rank( @@ -223,7 +247,7 @@ def __init__( self.micro_batch_size * self.n_micro_batches_per_batch * self.parallel_context.dp_pg.size() ) self.sequence_length = self.config.tokens.sequence_length - self.iteration_step = self.start_iteration_step + 
self.iteration_step = self.metadata.last_train_step self.limit_val_batches = self.config.tokens.limit_val_batches # NOTE: the dataloader currently in use for the current training stage self.current_dataloader: Optional[DataLoader] = None @@ -239,8 +263,10 @@ def post_init(self): def pre_training(self, *args, **kwargs): self._print_training_plan() + metadata: TrainingMetadata = self.metadata + log_rank( - f"[Start training] datetime: {datetime.datetime.now()} | mbs: {self.micro_batch_size} | grad_accum: {self.n_micro_batches_per_batch} | global_batch_size: {self.global_batch_size} | sequence_length: {self.sequence_length} | train_steps: {self.config.tokens.train_steps} | start_iteration_step: {self.start_iteration_step} | consumed_train_samples: {self.consumed_train_samples}", # noqa + f"[Start training] datetime: {datetime.datetime.now()} | mbs: {self.micro_batch_size} | grad_accum: {self.n_micro_batches_per_batch} | global_batch_size: {self.global_batch_size} | sequence_length: {self.sequence_length} | train_steps: {self.config.tokens.train_steps} | start_iteration_step: {metadata.last_train_step} | consumed_train_samples: {metadata.consumed_train_samples}", # noqa logger=logger, level=logging.INFO, rank=0, @@ -312,23 +338,45 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): gc.collect() dataloader = None - for stage_id, stage in enumerate(self.config.data_stages): + + def find_stage_idx_to_resume(): + reversed_data_stages = sorted(self.config.data_stages, key=lambda x: x.start_training_step, reverse=True) + for idx, stage in enumerate(reversed_data_stages): + if self.iteration_step >= stage.start_training_step: + return len(self.config.data_stages) - idx - 1 + return None + + stage_idx_to_resume = find_stage_idx_to_resume() + + for stage_idx, stage in enumerate(self.config.data_stages): + if stage_idx < self.metadata.last_stage_idx: + continue + stage = cast(DatasetStageArgs, stage) - if stage.start_training_step == self.iteration_step: + is_resume_from_training = self.current_dataloader is None and stage_idx_to_resume == stage_idx + if (stage.start_training_step == self.iteration_step) or is_resume_from_training: if self.current_dataloader is not None: - prev_stage_name = self.config.data_stages[stage_id - 1].name + prev_stage_name = self.config.data_stages[stage_idx - 1].name prev_dataloader = dataloaders[prev_stage_name] + if isinstance(prev_dataloader, DataLoader): # NOTE: we don't need to clear dummy data generator from memory clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) - log_rank( - f"[Training Stage: {stage.name}] Switching to a new dataset {stage.data.dataset.hf_dataset_or_datasets}", - logger=logger, - level=logging.INFO, - rank=0, - ) + self.metadata.last_stage_idx = stage_idx + + if is_resume_from_training: + remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( + stage, self.config, self.metadata + ) + consumed_train_steps = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, self.metadata) + log_rank( + f"Resuming training from stage {stage.name}, it has trained for {consumed_train_steps} samples and has {remaining_train_steps} remaining train steps", + logger=logger, + level=logging.INFO, + rank=0, + ) dataloader = dataloaders[stage.name] # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it @@ -368,7 +416,7 @@ def train( prof = get_profiler(config=self.config) torch.cuda.empty_cache() with prof: - for self.iteration_step in range(self.start_iteration_step + 1, 
self.config.tokens.train_steps + 1): + for self.iteration_step in range(self.metadata.last_train_step + 1, self.config.tokens.train_steps + 1): if isinstance(prof, torch.profiler.profile): prof.step() @@ -379,7 +427,12 @@ def train( outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) # Training Logs - self.consumed_train_samples += self.global_batch_size + # TODO(xrsrke): refactor using callbacks would be better + self.metadata.consumed_train_samples += self.global_batch_size + self.metadata.last_train_step = self.iteration_step + self.metadata.data_stages[ + self.metadata.last_stage_idx + ].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: self.train_step_logs(outputs=outputs, loss_avg=loss_avg) @@ -520,7 +573,9 @@ def train_step_logs( log_entries = [ # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( - "consumed_tokens", self.consumed_train_samples * self.config.tokens.sequence_length, "human_format" + "consumed_tokens", + self.metadata.consumed_train_samples * self.config.tokens.sequence_length, + "human_format", ), # , "12d"), LogItem("elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"), LogItem("tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), @@ -654,13 +709,12 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: parallel_context=self.parallel_context, root_folder=self.config.model.init_method.path, ) - elif isinstance(self.config.model.init_method, RandomInit): - + elif isinstance(self.config.model.init_method, (RandomInit, SpectralMupInit)): unwrapped_model.init_model_randomly(config=self.config) # Synchronize parameters so that the model is consistent # sync all params across dp - for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): + for _, param in sorted(model.named_parameters(), key=lambda x: x[0]): dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg) # sync tied params across tied groups @@ -797,15 +851,10 @@ def save_checkpoint(self) -> Path: dist.barrier(self.parallel_context.world_pg) log_rank(f"Saving checkpoint at {checkpoint_path}", logger=logger, level=logging.WARNING, rank=0) - checkpoint_metadata = { - "last_train_step": self.iteration_step, - # TODO: @nouamanetazi: Add more metadata to the checkpoint to be able to resume dataloader states properly - "consumed_train_samples": self.consumed_train_samples, - } # Update step/samples numbers before we save the config - self.config.general.step = self.iteration_step - self.config.general.consumed_train_samples = self.consumed_train_samples + self.config.general.step = self.metadata.last_train_step + self.config.general.consumed_train_samples = self.metadata.consumed_train_samples save( model=self.unwrapped_model, @@ -823,7 +872,7 @@ def save_checkpoint(self) -> Path: ), # We only save the config on world_rank==0 parallel_context=self.parallel_context, root_folder=checkpoint_path, - checkpoint_metadata=checkpoint_metadata, + training_metadata=self.metadata, config=self.config, ) save_random_states( @@ -833,9 +882,9 @@ def save_checkpoint(self) -> Path: fo.write(f"{self.iteration_step}") if hasattr(self.model_config, "to_json_file"): - self.model_config.to_json_file(checkpoint_path / "model_config.json") + self.model_config.to_json_file(checkpoint_path / MODEL_CONFIG_FILE_NAME) else: - with open(checkpoint_path / "model_config.json", mode="w") as fo: + 
with open(checkpoint_path / MODEL_CONFIG_FILE_NAME, mode="w") as fo: fo.write(json.dumps(asdict(self.model_config))) self.post_save_checkpoint() diff --git a/tests/helpers/llama.py b/tests/helpers/llama.py new file mode 100644 index 00000000..3f94031f --- /dev/null +++ b/tests/helpers/llama.py @@ -0,0 +1,137 @@ +import torch +from nanotron.config import ( + AllForwardAllBackwardPipelineEngine, + CheckpointsArgs, + Config, + DataArgs, + DatasetStageArgs, + GeneralArgs, + LlamaConfig, + LoggingArgs, + LRSchedulerArgs, + ModelArgs, + OptimizerArgs, + ParallelismArgs, + TensorParallelLinearMode, + TokenizerArgs, + TokensArgs, +) +from nanotron.config.config import PretrainDatasetsArgs +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel.context import ParallelContext +from nanotron.trainer import mark_tied_parameters + +TINY_LLAMA_CONFIG = LlamaConfig( + **{ + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 16, + "initializer_range": 0.02, + "intermediate_size": 32, + "is_llama_config": True, + "max_position_embeddings": 128, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pad_token_id": None, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 4096, + } +) + + +def get_llama_training_config(model_config: ModelArgs): + return Config( + model=model_config, + general=GeneralArgs(project="unittest", run="sanity_llama", seed=42), + checkpoints=CheckpointsArgs( + checkpoints_path="./checkpoints", + checkpoint_interval=10, + ), + parallelism=ParallelismArgs( + dp=1, + pp=1, + tp=2, + expert_parallel_size=2, + pp_engine="1f1b", + tp_mode="ALL_REDUCE", + tp_linear_async_communication=False, + ), + tokenizer=TokenizerArgs("gpt2"), + optimizer=OptimizerArgs( + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=False, + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + learning_rate_scheduler=LRSchedulerArgs( + learning_rate=3e-4, + lr_warmup_steps=100, + lr_warmup_style="linear", + lr_decay_style="cosine", + min_decay_lr=1e-5, + ), + ), + logging=LoggingArgs(), + tokens=TokensArgs(sequence_length=16, train_steps=10, micro_batch_size=16, batch_accumulation_per_replica=1), + data_stages=[ + DatasetStageArgs( + name="train", + start_training_step=1, + data=DataArgs( + seed=42, + num_loading_workers=1, + dataset=PretrainDatasetsArgs( + hf_dataset_or_datasets="HuggingFaceH4/testing_alpaca_small", + hf_dataset_splits="train", + text_column_name="completion", + dataset_processing_num_proc_per_process=12, + ), + ), + ) + ], + ) + + +def create_llama_from_config( + model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext +) -> LlamaForTraining: + + """ + Creates and returns a nanotron model. + If `model_config` is None, then `checkpoint_path` must be set, in which case + the configuration will be loaded from such path. + If `checkpoint_path` is None, then `model_config` must be set, in which case + the model created will have random weights. 
+ """ + + parallel_config = ParallelismArgs( + dp=parallel_context.data_parallel_size, + pp=parallel_context.pipeline_parallel_size, + tp=parallel_context.tensor_parallel_size, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + model = build_model( + model_builder=lambda: LlamaForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=torch.bfloat16, + device=device, + ) + mark_tied_parameters(model=model, parallel_context=parallel_context) + return model diff --git a/tests/test_base_model.py b/tests/test_base_model.py new file mode 100644 index 00000000..b4759905 --- /dev/null +++ b/tests/test_base_model.py @@ -0,0 +1,45 @@ +import pytest +import torch +import torch.distributed as dist +from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config +from helpers.utils import init_distributed, rerun_if_address_is_in_use +from nanotron.config import Config, ModelArgs, RandomInit +from nanotron.parallel import ParallelContext +from nanotron.parallel.pipeline_parallel.block import PipelineBlock +from torch import nn + + +@pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 2)]) +@pytest.mark.skip +@rerun_if_address_is_in_use() +def test_get_named_modules_in_pp_rank(tp: int, dp: int, pp: int): + model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) + config = get_llama_training_config(model_args) + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_get_named_modules_in_pp_rank)(config=config) + + +def _test_get_named_modules_in_pp_rank( + parallel_context: ParallelContext, + config: Config, +): + model = create_llama_from_config( + model_config=config.model.model_config, + device=torch.device("cuda"), + parallel_context=parallel_context, + ) + model.init_model_randomly(config=config) + + modules_that_not_in_current_pp_rank = {} + current_pp_rank = dist.get_rank(group=parallel_context.pp_pg) + for name, module in model.named_modules(): + if isinstance(module, PipelineBlock) and module.rank != current_pp_rank: + modules_that_not_in_current_pp_rank[name] = module + + named_modules_in_pp_rank = model.named_modules_in_pp_rank + + for name, module in named_modules_in_pp_rank.items(): + # NOTE: if a module is in the current rank, we expect it to be an initialized module + # not PipelineBlock + assert isinstance(module, nn.Module) + assert name not in modules_that_not_in_current_pp_rank diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 00000000..1a28f967 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,63 @@ +from typing import Union + +import pytest +import torch +from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config +from helpers.utils import init_distributed, rerun_if_address_is_in_use +from nanotron.config import ModelArgs, RandomInit, SpectralMupInit +from nanotron.helpers import get_custom_lr_for_named_parameters +from nanotron.parallel import ParallelContext +from nanotron.scaling.parametrization import ParametrizationMethod + + +@pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 1, 1), (1, 1, 2), (2, 1, 2)]) +@pytest.mark.parametrize( + "parametrization_method", [ParametrizationMethod.STANDARD, ParametrizationMethod.SPECTRAL_MUP] +) +@pytest.mark.skip +@rerun_if_address_is_in_use() +def test_get_custom_lr(tp: int, dp: int, pp: 
int, parametrization_method: ParametrizationMethod): + LR = 1e-3 + + if parametrization_method == ParametrizationMethod.STANDARD: + init_method = RandomInit(std=1.0) + elif parametrization_method == ParametrizationMethod.SPECTRAL_MUP: + init_method = SpectralMupInit(use_mup=True) + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_get_custom_lr)( + lr=LR, + init_method=init_method, + parametrization_method=parametrization_method, + ) + + +def _test_get_custom_lr( + parallel_context: ParallelContext, + lr: float, + init_method: Union[RandomInit, SpectralMupInit], + parametrization_method: ParametrizationMethod, +): + model_args = ModelArgs(init_method=init_method, model_config=TINY_LLAMA_CONFIG) + config = get_llama_training_config(model_args) + llama = create_llama_from_config( + model_config=TINY_LLAMA_CONFIG, + device=torch.device("cuda"), + parallel_context=parallel_context, + ) + llama.init_model_randomly(config=config, init_method=parametrization_method) + named_parameters = list(llama.get_named_params_with_correct_tied()) + + if len(named_parameters) == 0: + # NOTE: some pp ranks don't have any parameters + return + + named_param_groups = get_custom_lr_for_named_parameters( + parametrization_method=parametrization_method, lr=lr, named_parameters=named_parameters, model=llama + ) + + assert len(named_param_groups) == len(named_parameters) + assert all(isinstance(named_param_group["lr"], float) for named_param_group in named_param_groups) + assert all(isinstance(named_param_group["named_params"], list) for named_param_group in named_param_groups) + + is_all_lr_the_same = parametrization_method == ParametrizationMethod.STANDARD + assert all(named_param_group["lr"] == lr for named_param_group in named_param_groups) is is_all_lr_the_same diff --git a/tests/test_optimizer_params_groups.py b/tests/test_optimizer_params_groups.py new file mode 100644 index 00000000..fa835e1c --- /dev/null +++ b/tests/test_optimizer_params_groups.py @@ -0,0 +1,581 @@ +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use +from nanotron.optim.gradient_accumulator import FP32GradientAccumulator +from nanotron.optim.named_optimizer import NamedOptimizer +from nanotron.optim.optimizer_from_gradient_accumulator import OptimizerFromGradientAccumulator +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.random import set_random_seed + + +class DummyModel(nn.Module): + def __init__(self, dtype=torch.float32): + super(DummyModel, self).__init__() + self.fc1 = nn.Linear(10, 20, bias=False).to(dtype=dtype) + self.fc2 = nn.Linear(20, 2, bias=False).to(dtype=dtype) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return x + + +def test_optimizer_lr_one_group(): + set_random_seed(42) + + model = DummyModel().to("cuda") + + lr1 = 0.1 + + named_params_or_groups = [] + for name, param in model.named_parameters(): + named_params_or_groups.append((name, param)) + named_params_or_groups = [{"named_params": named_params_or_groups, "lr": lr1}] + + optimizer = NamedOptimizer( + named_params_or_groups=named_params_or_groups, + optimizer_builder=lambda param_groups: optim.SGD( + param_groups, + lr=9999999, # this is a dummy value that should be overwritten by the lr in the named_params_or_groups + ), + ) + + input = torch.randn(10, 10).to(device="cuda") + target = torch.randint(0, 2, 
(10,)).to(device="cuda") + + for _ in range(100): + optimizer.zero_grad() + + output = model(input) + loss = F.cross_entropy(output, target) + loss.backward() + + fc1_grad = model.fc1.weight.grad.clone() + fc2_grad = model.fc2.weight.grad.clone() + + # compute gradient manually + with torch.no_grad(): + expected_fc1_weight = model.fc1.weight - lr1 * fc1_grad + expected_fc2_weight = model.fc2.weight - lr1 * fc2_grad + + optimizer.step() + + updated_fc1_weight = model.fc1.weight + updated_fc2_weight = model.fc2.weight + + torch.testing.assert_close(expected_fc1_weight, updated_fc1_weight) + torch.testing.assert_close(expected_fc2_weight, updated_fc2_weight) + + +def test_optimizer_lr_multiple_group(): + set_random_seed(42) + + model = DummyModel().to("cuda") + + lr1, lr2 = 0.1, 0.001 + + named_params_or_groups = [ + {"named_params": [(name, param) for name, param in model.named_parameters() if "fc1" in name], "lr": lr1}, + {"named_params": [(name, param) for name, param in model.named_parameters() if "fc2" in name], "lr": lr2}, + ] + + optimizer = NamedOptimizer( + named_params_or_groups=named_params_or_groups, + optimizer_builder=lambda param_groups: optim.SGD( + param_groups, + lr=9999999, # this is a dummy value that should be overwritten by the lr in the named_params_or_groups + ), + ) + + input = torch.randn(10, 10).to(device="cuda") + target = torch.randint(0, 2, (10,)).to(device="cuda") + + for _ in range(100): + optimizer.zero_grad() + + output = model(input) + loss = F.cross_entropy(output, target) + loss.backward() + + fc1_grad = model.fc1.weight.grad.clone() + fc2_grad = model.fc2.weight.grad.clone() + + with torch.no_grad(): + expected_fc1_weight = model.fc1.weight - lr1 * fc1_grad + expected_fc2_weight = model.fc2.weight - lr2 * fc2_grad + + optimizer.step() + + updated_fc1_weight = model.fc1.weight + updated_fc2_weight = model.fc2.weight + + torch.testing.assert_close(expected_fc1_weight, updated_fc1_weight) + torch.testing.assert_close(expected_fc2_weight, updated_fc2_weight) + + +def test_optimizer_lr_weight_decay_one_group(): + set_random_seed(42) + + model = DummyModel().to("cuda") + + lr1 = 0.1 + weight_decay = 0.1 + + named_params_or_groups = [] + for name, param in model.named_parameters(): + named_params_or_groups.append((name, param)) + named_params_or_groups = [{"named_params": named_params_or_groups, "lr": lr1, "weight_decay": weight_decay}] + + optimizer = NamedOptimizer( + named_params_or_groups=named_params_or_groups, + optimizer_builder=lambda param_groups: optim.SGD( + param_groups, + lr=9999999, # this is a dummy value that should be overwritten by the lr in the named_params_or_groups + ), + ) + + input = torch.randn(10, 10).to(device="cuda") + target = torch.randint(0, 2, (10,)).to(device="cuda") + + for _ in range(100): + optimizer.zero_grad() + + output = model(input) + loss = F.cross_entropy(output, target) + loss.backward() + + # Compute gradient manually and apply weight decay + with torch.no_grad(): + expected_fc1_weight = (1 - lr1 * weight_decay) * model.fc1.weight - lr1 * model.fc1.weight.grad + expected_fc2_weight = (1 - lr1 * weight_decay) * model.fc2.weight - lr1 * model.fc2.weight.grad + + optimizer.step() + + updated_fc1_weight = model.fc1.weight + updated_fc2_weight = model.fc2.weight + + torch.testing.assert_close(expected_fc1_weight, updated_fc1_weight) + torch.testing.assert_close(expected_fc2_weight, updated_fc2_weight) + + +def test_optimizer_lr_weight_decay_multiple_group(): + set_random_seed(42) + + model = DummyModel().to("cuda") + + lr1, 
lr2 = 0.1, 0.001 + weight_decay1, weight_decay2 = 0.1, 0.001 + + named_params_or_groups = [ + { + "named_params": [(name, param) for name, param in model.named_parameters() if "fc1" in name], + "lr": lr1, + "weight_decay": weight_decay1, + }, + { + "named_params": [(name, param) for name, param in model.named_parameters() if "fc2" in name], + "lr": lr2, + "weight_decay": weight_decay2, + }, + ] + + optimizer = NamedOptimizer( + named_params_or_groups=named_params_or_groups, + optimizer_builder=lambda param_groups: optim.SGD( + param_groups, + lr=9999999, # this is a dummy value that should be overwritten by the lr in the named_params_or_groups + ), + ) + + input = torch.randn(10, 10).to(device="cuda") + target = torch.randint(0, 2, (10,)).to(device="cuda") + + for _ in range(100): + optimizer.zero_grad() + + output = model(input) + loss = F.cross_entropy(output, target) + loss.backward() + + # Compute gradient manually and apply weight decay + with torch.no_grad(): + expected_fc1_weight = (1 - lr1 * weight_decay1) * model.fc1.weight - lr1 * model.fc1.weight.grad + expected_fc2_weight = (1 - lr2 * weight_decay2) * model.fc2.weight - lr2 * model.fc2.weight.grad + + optimizer.step() + + updated_fc1_weight = model.fc1.weight + updated_fc2_weight = model.fc2.weight + + torch.testing.assert_close(expected_fc1_weight, updated_fc1_weight) + torch.testing.assert_close(expected_fc2_weight, updated_fc2_weight) + + +@pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("accumulation_steps", [1, 10]) +def test_optimizer_grad_accumulation_lr_one_group(half_precision: torch.dtype, accumulation_steps: int): + set_random_seed(42) + dtype = half_precision + lr1 = 0.1 + + model = DummyModel(dtype=dtype).to("cuda") + + # Need to convert the weights to NanotronParameter for the gradient accumulation to work + model.fc1.weight = NanotronParameter(model.fc1.weight) + model.fc2.weight = NanotronParameter(model.fc2.weight) + + named_params_or_groups = [] + for name, param in model.named_parameters(): + named_params_or_groups.append((name, param)) + + named_params_or_groups = [{"named_params": named_params_or_groups, "lr": lr1}] + + # Optimizer + def optimizer_builder(inp_param_groups): + return NamedOptimizer( + named_params_or_groups=inp_param_groups, + optimizer_builder=lambda param_groups: optim.SGD( + param_groups, + lr=9999999, # this is a dummy value that should be overwritten by the lr in the named_params_or_groups + ), + ) + + optimizer = OptimizerFromGradientAccumulator( + gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(named_parameters=named_params), + named_params_or_groups=named_params_or_groups, + optimizer_builder=optimizer_builder, + ) + + accumulator = optimizer.gradient_accumulator + + input = torch.randn(10, 10, dtype=dtype).to(device="cuda") + target = torch.randint(0, 2, (10,)).to(device="cuda") + + for batch_idx in range(100): + optimizer.zero_grad() + + output = model(input) + loss = F.cross_entropy(output.float(), target) + accumulator.backward(loss) + + if (batch_idx + 1) % accumulation_steps == 0: + + # Manual update weights for ref + with torch.no_grad(): + fc1_grad = accumulator.get_grad_buffer(name="fc1.weight").to(dtype) + expected_fc1_weight = model.fc1.weight - lr1 * fc1_grad + + fc2_grad = accumulator.get_grad_buffer(name="fc2.weight").to(dtype) + expected_fc2_weight = model.fc2.weight - lr1 * fc2_grad + + optimizer.step() + + updated_fc1_weight = model.fc1.weight + updated_fc2_weight = model.fc2.weight + + 
torch.testing.assert_close(expected_fc1_weight, updated_fc1_weight)
+    torch.testing.assert_close(expected_fc2_weight, updated_fc2_weight)
+
+
+@pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("accumulation_steps", [1, 10])
+def test_optimizer_grad_accumulation_lr_multiple_group(half_precision: torch.dtype, accumulation_steps: int):
+    set_random_seed(42)
+    dtype = half_precision
+    lr1, lr2 = 0.1, 0.001
+
+    model = DummyModel(dtype=dtype).to("cuda")
+
+    # Need to convert the weights to NanotronParameter for the gradient accumulation to work
+    model.fc1.weight = NanotronParameter(model.fc1.weight)
+    model.fc2.weight = NanotronParameter(model.fc2.weight)
+
+    named_params_or_groups = [
+        {"named_params": [(name, param) for name, param in model.named_parameters() if "fc1" in name], "lr": lr1},
+        {"named_params": [(name, param) for name, param in model.named_parameters() if "fc2" in name], "lr": lr2},
+    ]
+
+    # Optimizer
+    def optimizer_builder(inp_param_groups):
+        return NamedOptimizer(
+            named_params_or_groups=inp_param_groups,
+            optimizer_builder=lambda param_groups: optim.SGD(
+                param_groups,
+                lr=9999999,  # this is a dummy value that should be overwritten by the lr in the named_params_or_groups
+            ),
+        )
+
+    optimizer = OptimizerFromGradientAccumulator(
+        gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(named_parameters=named_params),
+        named_params_or_groups=named_params_or_groups,
+        optimizer_builder=optimizer_builder,
+    )
+
+    accumulator = optimizer.gradient_accumulator
+
+    input = torch.randn(10, 10, dtype=dtype).to(device="cuda")
+    target = torch.randint(0, 2, (10,)).to(device="cuda")
+
+    for batch_idx in range(100):
+        optimizer.zero_grad()
+
+        output = model(input)
+        loss = F.cross_entropy(output.float(), target)
+        accumulator.backward(loss)
+
+        if (batch_idx + 1) % accumulation_steps == 0:
+
+            # Manual update weights for ref
+            with torch.no_grad():
+                fc1_grad = accumulator.get_grad_buffer(name="fc1.weight").to(dtype)
+                expected_fc1_weight = model.fc1.weight - lr1 * fc1_grad
+
+                fc2_grad = accumulator.get_grad_buffer(name="fc2.weight").to(dtype)
+                expected_fc2_weight = model.fc2.weight - lr2 * fc2_grad
+
+            optimizer.step()
+
+            updated_fc1_weight = model.fc1.weight
+            updated_fc2_weight = model.fc2.weight
+
+            torch.testing.assert_close(expected_fc1_weight, updated_fc1_weight)
+            torch.testing.assert_close(expected_fc2_weight, updated_fc2_weight)
+
+
+@pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("accumulation_steps", [1, 10])
+def test_optimizer_grad_accumulation_lr_weight_decay_one_group(half_precision: torch.dtype, accumulation_steps: int):
+    set_random_seed(42)
+    dtype = half_precision
+    lr1 = 0.1
+    weight_decay = 0.1
+
+    model = DummyModel(dtype=dtype).to("cuda")
+
+    # Need to convert the weights to NanotronParameter for the gradient accumulation to work
+    model.fc1.weight = NanotronParameter(model.fc1.weight)
+    model.fc2.weight = NanotronParameter(model.fc2.weight)
+
+    named_params_or_groups = []
+    for name, param in model.named_parameters():
+        named_params_or_groups.append((name, param))
+    named_params_or_groups = [{"named_params": named_params_or_groups, "lr": lr1, "weight_decay": weight_decay}]
+
+    # Optimizer
+    def optimizer_builder(inp_param_groups):
+        return NamedOptimizer(
+            named_params_or_groups=inp_param_groups,
+            optimizer_builder=lambda param_groups: optim.SGD(
+                param_groups,
+                lr=9999999,  # this is a dummy value that will be overwritten by the lr in the named_params_or_groups
+                weight_decay=9999999,  # this is a dummy value that will be overwritten by the weight_decay in the named_params_or_groups
+            ),
+        )
+
+    optimizer = OptimizerFromGradientAccumulator(
+        gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(named_parameters=named_params),
+        named_params_or_groups=named_params_or_groups,
+        optimizer_builder=optimizer_builder,
+    )
+
+    accumulator = optimizer.gradient_accumulator
+
+    input = torch.randn(10, 10, dtype=dtype).to(device="cuda")
+    target = torch.randint(0, 2, (10,)).to(device="cuda")
+
+    for batch_idx in range(100):
+        optimizer.zero_grad()
+
+        output = model(input)
+        loss = F.cross_entropy(output.float(), target)
+        accumulator.backward(loss)
+
+        if (batch_idx + 1) % accumulation_steps == 0:
+
+            # Manual update weights for ref
+            with torch.no_grad():
+                fc1_grad = accumulator.get_grad_buffer(name="fc1.weight").to(dtype)
+                expected_fc1_weight = (1 - lr1 * weight_decay) * model.fc1.weight - lr1 * fc1_grad
+
+                fc2_grad = accumulator.get_grad_buffer(name="fc2.weight").to(dtype)
+                expected_fc2_weight = (1 - lr1 * weight_decay) * model.fc2.weight - lr1 * fc2_grad
+
+            optimizer.step()
+
+            updated_fc1_weight = model.fc1.weight
+            updated_fc2_weight = model.fc2.weight
+
+            torch.testing.assert_close(expected_fc1_weight, updated_fc1_weight)
+            torch.testing.assert_close(expected_fc2_weight, updated_fc2_weight)
+
+
+@pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("accumulation_steps", [1, 10])
+def test_optimizer_grad_accumulation_lr_weight_decay_multiple_group(
+    half_precision: torch.dtype, accumulation_steps: int
+):
+    set_random_seed(42)
+    dtype = half_precision
+    lr1, lr2 = 0.1, 0.001
+    weight_decay1, weight_decay2 = 0.1, 0.001
+
+    model = DummyModel(dtype=dtype).to("cuda")
+
+    # Need to convert the weights to NanotronParameter for the gradient accumulation to work
+    model.fc1.weight = NanotronParameter(model.fc1.weight)
+    model.fc2.weight = NanotronParameter(model.fc2.weight)
+
+    named_params_or_groups = [
+        {
+            "named_params": [(name, param) for name, param in model.named_parameters() if "fc1" in name],
+            "lr": lr1,
+            "weight_decay": weight_decay1,
+        },
+        {
+            "named_params": [(name, param) for name, param in model.named_parameters() if "fc2" in name],
+            "lr": lr2,
+            "weight_decay": weight_decay2,
+        },
+    ]
+    # Optimizer
+    def optimizer_builder(inp_param_groups):
+        return NamedOptimizer(
+            named_params_or_groups=inp_param_groups,
+            optimizer_builder=lambda param_groups: optim.SGD(
+                param_groups,
+                lr=9999999,  # this is a dummy value that will be overwritten by the lr in the named_params_or_groups
+                weight_decay=9999999,  # this is a dummy value that will be overwritten by the weight_decay in the named_params_or_groups
+            ),
+        )
+
+    optimizer = OptimizerFromGradientAccumulator(
+        gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(named_parameters=named_params),
+        named_params_or_groups=named_params_or_groups,
+        optimizer_builder=optimizer_builder,
+    )
+
+    accumulator = optimizer.gradient_accumulator
+
+    input = torch.randn(10, 10, dtype=dtype).to(device="cuda")
+    target = torch.randint(0, 2, (10,)).to(device="cuda")
+
+    for batch_idx in range(100):
+        optimizer.zero_grad()
+
+        output = model(input)
+        loss = F.cross_entropy(output.float(), target)
+        accumulator.backward(loss)
+
+        if (batch_idx + 1) % accumulation_steps == 0:
+
+            # Manual update weights for ref
+            with torch.no_grad():
+                fc1_grad = accumulator.get_grad_buffer(name="fc1.weight").to(dtype)
+                expected_fc1_weight = (1 - lr1 * weight_decay1) * model.fc1.weight - lr1 * fc1_grad
+
+                fc2_grad = accumulator.get_grad_buffer(name="fc2.weight").to(dtype)
+                expected_fc2_weight = (1 - lr2 * weight_decay2) * model.fc2.weight - lr2 * fc2_grad
+
+            optimizer.step()
+
+            updated_fc1_weight = model.fc1.weight
+            updated_fc2_weight = model.fc2.weight
+
+            torch.testing.assert_close(expected_fc1_weight, updated_fc1_weight)
+            torch.testing.assert_close(expected_fc2_weight, updated_fc2_weight)
+
+
+@pytest.mark.skipif(available_gpus() < 2, reason="Testing requires at least 2 gpus")
+@pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("accumulation_steps", [1, 10])
+@rerun_if_address_is_in_use()
+def test_ddp_optimizer_grad_accumulation_lr_weight_decay_multiple_group(
+    half_precision: torch.dtype, accumulation_steps: int
+):
+    init_distributed(tp=1, dp=2, pp=1)(_test_ddp_optimizer_grad_accumulation_lr_weight_decay_multiple_group)(
+        half_precision=half_precision,
+        accumulation_steps=accumulation_steps,
+    )
+
+
+def _test_ddp_optimizer_grad_accumulation_lr_weight_decay_multiple_group(
+    parallel_context: ParallelContext, half_precision: torch.dtype, accumulation_steps: int
+):
+    set_random_seed(42)
+    dtype = half_precision
+    # Making it bigger so that the difference is more visible during update
+    lr1, lr2 = 0.04, 0.05
+    weight_decay1, weight_decay2 = 0.5, 0.2
+
+    model = DummyModel(dtype=dtype).to("cuda")
+    # Need to convert the weights to NanotronParameter for the gradient accumulation to work
+    model.fc1.weight = NanotronParameter(model.fc1.weight)
+    model.fc2.weight = NanotronParameter(model.fc2.weight)
+
+    model_ddp = torch.nn.parallel.DistributedDataParallel(
+        model,
+        process_group=parallel_context.dp_pg,
+    )
+
+    named_params_or_groups = [
+        {
+            "named_params": [(name, param) for name, param in model_ddp.named_parameters() if "fc1" in name],
+            "lr": lr1,
+            "weight_decay": weight_decay1,
+        },
+        {
+            "named_params": [(name, param) for name, param in model_ddp.named_parameters() if "fc2" in name],
+            "lr": lr2,
+            "weight_decay": weight_decay2,
+        },
+    ]
+    # Optimizer
+    def optimizer_builder(inp_param_groups):
+        return NamedOptimizer(
+            named_params_or_groups=inp_param_groups,
+            optimizer_builder=lambda param_groups: optim.SGD(
+                param_groups,
+                lr=9999999,  # this is a dummy value that will be overwritten by the lr in the named_params_or_groups
+                weight_decay=9999999,  # this is a dummy value that will be overwritten by the weight_decay in the named_params_or_groups
+            ),
+        )
+
+    optimizer = OptimizerFromGradientAccumulator(
+        gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(named_parameters=named_params),
+        named_params_or_groups=named_params_or_groups,
+        optimizer_builder=optimizer_builder,
+    )
+
+    accumulator = optimizer.gradient_accumulator
+
+    input = torch.randn(10, 10, dtype=dtype).to(device="cuda")
+    target = torch.randint(0, 2, (10,)).to(device="cuda")
+
+    for batch_idx in range(100):
+        optimizer.zero_grad()
+
+        output = model(input)
+        loss = F.cross_entropy(output.float(), target)
+        accumulator.backward(loss)
+
+        if (batch_idx + 1) % accumulation_steps == 0:
+
+            # Manual update weights for ref
+            with torch.no_grad():
+                fc1_grad = accumulator.get_grad_buffer(name="module.fc1.weight").to(dtype)
+                expected_fc1_weight = (1 - lr1 * weight_decay1) * model.fc1.weight - lr1 * fc1_grad
+
+                fc2_grad = accumulator.get_grad_buffer(name="module.fc2.weight").to(dtype)
+                expected_fc2_weight = (1 - lr2 * weight_decay2) * model.fc2.weight - lr2 * fc2_grad
+
+            optimizer.step()
+
+            updated_fc1_weight = model.fc1.weight
+            updated_fc2_weight = model.fc2.weight
+
+            torch.testing.assert_close(expected_fc1_weight, updated_fc1_weight)
+            torch.testing.assert_close(expected_fc2_weight, updated_fc2_weight)
diff --git a/tests/test_parametrization.py b/tests/test_parametrization.py
new file mode 100644
index 00000000..fe76826a
--- /dev/null
+++ b/tests/test_parametrization.py
@@ -0,0 +1,76 @@
+import math
+from typing import Union
+
+import pytest
+import torch
+from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config
+from helpers.utils import init_distributed, rerun_if_address_is_in_use
+from nanotron.config import ModelArgs, RandomInit, SpectralMupInit
+from nanotron.parallel import ParallelContext
+from nanotron.scaling.parametrization import ParametrizationMethod
+
+
+@pytest.mark.parametrize("tp,dp,pp", [(2, 1, 1)])
+@pytest.mark.parametrize("parametrization_method", [ParametrizationMethod.SPECTRAL_MUP])
+@pytest.mark.skip
+@rerun_if_address_is_in_use()
+def test_parametrization(tp: int, dp: int, pp: int, parametrization_method: ParametrizationMethod):
+    if parametrization_method == ParametrizationMethod.STANDARD:
+        init_method = RandomInit(std=1.0)
+    elif parametrization_method == ParametrizationMethod.SPECTRAL_MUP:
+        init_method = SpectralMupInit(use_mup=True)
+
+    init_distributed(tp=tp, dp=dp, pp=pp)(_test_parametrization)(
+        init_method=init_method,
+        parametrization_method=parametrization_method,
+    )
+
+
+def _test_parametrization(
+    parallel_context: ParallelContext,
+    init_method: Union[RandomInit, SpectralMupInit],
+    parametrization_method: ParametrizationMethod,
+):
+    def spectral_std(fan_in: int, fan_out: int):
+        return torch.tensor((1.0 / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in)))
+
+    model_args = ModelArgs(init_method=init_method, model_config=TINY_LLAMA_CONFIG)
+    config = get_llama_training_config(model_args)
+
+    llama = create_llama_from_config(
+        model_config=TINY_LLAMA_CONFIG,
+        device=torch.device("cuda"),
+        parallel_context=parallel_context,
+    )
+    llama.init_model_randomly(config=config, init_method=parametrization_method)
+
+    hidden_size = TINY_LLAMA_CONFIG.hidden_size
+    intermediate_size = TINY_LLAMA_CONFIG.intermediate_size
+
+    o_proj_in_features = llama.model.decoder[0].pp_block.attn.o_proj.in_features * parallel_context.tensor_parallel_size
+    NAME_TO_EXPECTED_STD = {
+        "input_layernorm": torch.tensor(0.0),
+        "post_attention_layernorm": torch.tensor(0.0),
+        "final_layer_norm": torch.tensor(0.0),
+        "token_embedding": torch.tensor(1.0),
+        # "lm_head": torch.tensor(1.0),
+        "qkv_proj": spectral_std(fan_in=hidden_size, fan_out=intermediate_size),
+        "o_proj": spectral_std(fan_in=o_proj_in_features, fan_out=hidden_size),
+        "gate_up_proj": spectral_std(fan_in=hidden_size, fan_out=intermediate_size),
+        "down_proj": spectral_std(fan_in=intermediate_size, fan_out=hidden_size),
+    }
+
+    def find_expected_std(param_name):
+        for name in NAME_TO_EXPECTED_STD:
+            if name in param_name:
+                return NAME_TO_EXPECTED_STD[name]
+
+    for name, param in llama.model.named_parameters():
+        if "lm_head" in name:
+            continue
+
+        expected_std = find_expected_std(name)
+        assert expected_std is not None, f"Could not find expected std for {name}"
+        assert torch.allclose(
+            param.std().float(), expected_std, atol=0.05
+        ), f"name: {name}, expected: {expected_std}, actual: {param.std()}"