From 92cb5004a9c1e5bb672c807c3b6c88877e2a8f43 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Sun, 5 May 2024 23:54:56 +0000 Subject: [PATCH 1/6] add custom dl example --- examples/custom-dataloader/README.md | 23 ++ .../custom-dataloader/config_custom_dl.yaml | 109 +++++++++ examples/custom-dataloader/run_train.py | 222 ++++++++++++++++++ 3 files changed, 354 insertions(+) create mode 100644 examples/custom-dataloader/README.md create mode 100644 examples/custom-dataloader/config_custom_dl.yaml create mode 100644 examples/custom-dataloader/run_train.py diff --git a/examples/custom-dataloader/README.md b/examples/custom-dataloader/README.md new file mode 100644 index 00000000..f65c13ac --- /dev/null +++ b/examples/custom-dataloader/README.md @@ -0,0 +1,23 @@ +# Use a custom dataloader with Nanotron + +## Usage +``` +export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations +torchrun --nproc_per_node=2 examples/custom-dataloader/run_train.py --config-file examples/custom-dataloader/config_custom_dl.yaml +``` + +## Troubleshooting + +### `return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)` +``` + File "/fsx/nouamane/projects/nanotron/src/nanotron/parallel/tensor_parallel/nn.py", line 284, in forward + out = super().forward(masked_input) + File "/fsx/nouamane/miniconda/envs/2-1-cu121/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward + return F.embedding( + File "/fsx/nouamane/miniconda/envs/2-1-cu121/lib/python3.10/site-packages/torch/nn/functional.py", line 2233, in embedding + return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) +RuntimeError: CUDA error: device-side assert triggered +Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. +``` + +If you encounter an error with `torch.embedding`, it's probable you're feeding a token which is bigger than the model's vocabulary size. 
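+
+A quick way to confirm the mismatch (a minimal sketch, using the tokenizer and `vocab_size` from this example's config; substitute your own values):
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("robot-test/dummy-tokenizer-wordlevel")
+model_vocab_size = 256  # model.model_config.vocab_size in config_custom_dl.yaml
+
+# len(tokenizer) also counts added special tokens, unlike tokenizer.vocab_size
+assert len(tokenizer) <= model_vocab_size, (
+    f"tokenizer can emit ids up to {len(tokenizer) - 1}, "
+    f"but the model's embedding matrix only has {model_vocab_size} rows"
+)
+```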
Check your model's vocab size and tokenizer diff --git a/examples/custom-dataloader/config_custom_dl.yaml b/examples/custom-dataloader/config_custom_dl.yaml new file mode 100644 index 00000000..81941f1f --- /dev/null +++ b/examples/custom-dataloader/config_custom_dl.yaml @@ -0,0 +1,109 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Annealing Phase + start_training_step: 10 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 16 + initializer_range: 0.02 + intermediate_size: 64 + is_llama_config: true + max_position_embeddings: 256 + num_attention_heads: 4 + num_hidden_layers: 2 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 2 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 256 + train_steps: 15 + val_check_interval: -1 diff --git a/examples/custom-dataloader/run_train.py b/examples/custom-dataloader/run_train.py new file mode 100644 index 00000000..e1995381 --- /dev/null +++ b/examples/custom-dataloader/run_train.py @@ -0,0 +1,222 @@ +""" +Nanotron training script example using a custom dataloader. 
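+When the config sets `dataset: null`, it feeds the model randomly generated token ids instead of a HuggingFace dataset, which marks the place to plug in your own tokenized data.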
+ +Usage: +``` +export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations +torchrun --nproc_per_node=2 examples/custom-dataloader/run_train.py --config-file examples/custom-dataloader/config_custom_dl.yaml +``` +""" +import argparse +from typing import Dict, cast + +import datasets +import numpy as np +from nanotron import logging +from nanotron.config import ( + DataArgs, + DatasetStageArgs, + PretrainDatasetsArgs, +) +from nanotron.dataloader import ( + DataCollatorForCLM, + clm_process, + get_dataloader_worker_init, + get_datasets, + get_train_dataloader, +) +from nanotron.helpers import ( + compute_remain_train_steps_of_a_data_stage_from_ckp, + get_consumed_train_samples_of_a_data_stage_from_ckp, +) +from nanotron.logging import log_rank +from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks +from nanotron.trainer import DistributedTrainer +from nanotron.utils import main_rank_first +from torch.utils.data import DataLoader + +try: + from huggingface_hub import __version__ as hf_hub_version + from transformers import AutoTokenizer + from transformers import __version__ as tf_version +except ImportError: + hf_hub_version = None + tf_version = None + +logger = logging.get_logger(__name__) + + +def get_dataloader_from_data_stage( + trainer: DistributedTrainer, + data: DataArgs, + consumed_train_samples: int, + num_remaining_train_steps: int, +): + """ + Returns a dataloader for a given data stage. + + data: The data configuration for the current stage. + consumed_train_samples: The number of samples consumed by the model in the this stage (each stage starts from zero). + num_remaining_train_steps: The number of remaining training steps for this stage. + """ + assert consumed_train_samples >= 0, "consumed_train_samples should be greater than 0" + assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be greater than 0" + + # First, we need to know which ranks to feed the dataloader to + input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) + + # Case 1: custom data generator + if data.dataset is None: + log_rank("Using custom data generator", logger=logger, level=logging.INFO, rank=0) + + ########################################################################################################### + # This can be replaced with your own tokenized data generator + ########################################################################################################### + train_dataset = datasets.Dataset.from_dict( + { + "input_ids": np.random.randint( + 0, + trainer.config.model.model_config.vocab_size, + (trainer.global_batch_size * num_remaining_train_steps, trainer.sequence_length + 1), + ), + } + ) + ########################################################################################################### + + data_collator = DataCollatorForCLM( + sequence_length=trainer.sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=trainer.parallel_context, + ) + + return DataLoader( + train_dataset, + batch_size=trainer.micro_batch_size, + collate_fn=data_collator, + drop_last=True, + num_workers=0, + pin_memory=True, + worker_init_fn=get_dataloader_worker_init(dp_rank=trainer.parallel_context.dp_pg.rank()), + ) + + # Case 2: HuggingFace datasets + elif isinstance(data.dataset, PretrainDatasetsArgs): + log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0) + tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path + log_rank( + 
f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + # We need to the 1st device to process dataset and cache it, then other devices load from cache + with main_rank_first(trainer.parallel_context.world_pg): + # We load the raw dataset + raw_dataset = get_datasets( + hf_dataset_or_datasets=data.dataset.hf_dataset_or_datasets, + hf_dataset_config_name=data.dataset.hf_dataset_config_name, + splits=data.dataset.hf_dataset_splits, + )["train"] + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + # We apply the Causal Language Modeling preprocessing + train_dataset = clm_process( + raw_dataset=raw_dataset, + tokenizer=tokenizer, + text_column_name=data.dataset.text_column_name, + dataset_processing_num_proc_per_process=data.dataset.dataset_processing_num_proc_per_process, + dataset_overwrite_cache=data.dataset.dataset_overwrite_cache, + sequence_length=trainer.sequence_length, + ) + + # We load the processed dataset on the ranks requiring it + dataloader = get_train_dataloader( + train_dataset=train_dataset, + sequence_length=trainer.sequence_length, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + consumed_train_samples=consumed_train_samples, + dataloader_num_workers=data.num_loading_workers, + seed_worker=data.seed, + dataloader_drop_last=True, + ) + + # Check if we have enough samples for train_steps + total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length + num_tokens_needed_for_training = ( + num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length + ) + assert num_tokens_needed_for_training <= total_tokens_dataset, ( + f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), " + f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}" + ) + else: + raise ValueError(f"Unhandled case of `self.config.data.dataset`. 
Got: {data.dataset}") + + return dataloader + + +def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: + dataloaders = {} + + for stage_idx, stage in enumerate(trainer.config.data_stages): + # NOTE: we only create the dataloader for the first stage, + # then we lazy initialize the dataloader for the other stages + stage = cast(DatasetStageArgs, stage) + consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata) + assert ( + consumed_train_samples is not None + ), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint" + + num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( + stage, trainer.config, trainer.metadata + ) + log_rank( + f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dataloader = ( + get_dataloader_from_data_stage( + trainer, + stage.data, + consumed_train_samples=consumed_train_samples, + num_remaining_train_steps=num_remaining_train_steps, + ) + if stage_idx == 0 + else lambda stage=stage: get_dataloader_from_data_stage( + trainer, + stage.data, + consumed_train_samples=consumed_train_samples, + num_remaining_train_steps=num_remaining_train_steps, + ) + ) + dataloaders[stage.name] = dataloader + return dataloaders + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + config_file = args.config_file + + # Load trainer and data + trainer = DistributedTrainer(config_file) + dataloader = get_dataloader(trainer) + + # Train + trainer.train(dataloader) From 2e21db0db46a40bedbd03714616dd0ae4ea75914 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Sun, 5 May 2024 23:57:00 +0000 Subject: [PATCH 2/6] add assert vocab_size matches model's vocab_size --- examples/config_tiny_llama.py | 4 ++-- examples/config_tiny_llama.yaml | 4 ++-- run_train.py | 5 +++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/config_tiny_llama.py b/examples/config_tiny_llama.py index 765a353d..479e1d47 100644 --- a/examples/config_tiny_llama.py +++ b/examples/config_tiny_llama.py @@ -102,7 +102,7 @@ ), ] -checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints" +checkpoints_path = "./checkpoints" os.makedirs(checkpoints_path, exist_ok=True) config = Config( @@ -110,7 +110,7 @@ checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), parallelism=parallelism, model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config), - tokenizer=TokenizerArgs("gpt2"), + tokenizer=TokenizerArgs("robot-test/dummy-tokenizer-wordlevel"), optimizer=optimizer, logging=LoggingArgs(), tokens=tokens, diff --git a/examples/config_tiny_llama.yaml b/examples/config_tiny_llama.yaml index ab358b05..58645e2d 100644 --- a/examples/config_tiny_llama.yaml +++ b/examples/config_tiny_llama.yaml @@ -1,6 +1,6 @@ checkpoints: checkpoint_interval: 10 - checkpoints_path: /fsx/nouamane/projects/nanotron/checkpoints + checkpoints_path: checkpoints checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null save_initial_state: false @@ -97,7 +97,7 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: gpt2 + tokenizer_name_or_path: 
robot-test/dummy-tokenizer-wordlevel tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 diff --git a/run_train.py b/run_train.py index 617d231b..fd346313 100644 --- a/run_train.py +++ b/run_train.py @@ -102,6 +102,11 @@ def get_dataloader_from_data_stage( tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" + # Check that tokenizer's vocab size is smaller than the model's vocab size + assert ( + tokenizer.vocab_size <= trainer.model_config.vocab_size + ), f"Tokenizer's vocab size ({tokenizer.vocab_size}) is larger than the model's vocab size ({trainer.model_config.vocab_size})" + # We apply the Causal Language Modeling preprocessing train_dataset = clm_process( raw_dataset=raw_dataset, From 9e211e9079b465c8bd8b9ac8c3d80ae2c901cff8 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Mon, 6 May 2024 00:02:41 +0000 Subject: [PATCH 3/6] readme --- examples/custom-dataloader/README.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/custom-dataloader/README.md b/examples/custom-dataloader/README.md index f65c13ac..eb2ef520 100644 --- a/examples/custom-dataloader/README.md +++ b/examples/custom-dataloader/README.md @@ -1,10 +1,9 @@ # Use a custom dataloader with Nanotron -## Usage -``` -export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations -torchrun --nproc_per_node=2 examples/custom-dataloader/run_train.py --config-file examples/custom-dataloader/config_custom_dl.yaml -``` +This example shows how to use a custom dataloader with Nanotron. We will use a simple dataloader that loads a random tokenized dataset and feeds it to a Nanotron model. +https://github.com/huggingface/nanotron/blob/2e21db0db46a40bedbd03714616dd0ae4ea75914/examples/custom-dataloader/run_train.py#L72-L84 + +`DataCollatorForCLM` is a custom data collator that takes a list of input_ids and returns a dictionary with the input_ids and the labels on the ranks which need it. For example `input_ids` are only needed in the first PP rank, while `labels` are needed in the last PP rank. ## Troubleshooting From 8eb838813267153b01ec65708a495d7991d62f5f Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Mon, 6 May 2024 00:03:37 +0000 Subject: [PATCH 4/6] . --- examples/custom-dataloader/README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/custom-dataloader/README.md b/examples/custom-dataloader/README.md index eb2ef520..54d98a5a 100644 --- a/examples/custom-dataloader/README.md +++ b/examples/custom-dataloader/README.md @@ -5,6 +5,13 @@ https://github.com/huggingface/nanotron/blob/2e21db0db46a40bedbd03714616dd0ae4ea `DataCollatorForCLM` is a custom data collator that takes a list of input_ids and returns a dictionary with the input_ids and the labels on the ranks which need it. For example `input_ids` are only needed in the first PP rank, while `labels` are needed in the last PP rank. +To try it out you can run the following command: + +```bash +export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations +torchrun --nproc_per_node=2 examples/custom-dataloader/run_train.py --config-file examples/custom-dataloader/config_custom_dl.yaml +``` + ## Troubleshooting ### `return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)` From 8864d94f3b7a44858fe212bcc2dfdbe5fc3e0140 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Mon, 6 May 2024 00:06:03 +0000 Subject: [PATCH 5/6] . 
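
Document the `dataset: null` option in the README and switch the first data
stage of config_custom_dl.yaml to it, so the example exercises the custom
dataloader path out of the box.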
--- examples/custom-dataloader/README.md | 10 ++++++++++ examples/custom-dataloader/config_custom_dl.yaml | 8 +------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/custom-dataloader/README.md b/examples/custom-dataloader/README.md index 54d98a5a..b705794e 100644 --- a/examples/custom-dataloader/README.md +++ b/examples/custom-dataloader/README.md @@ -5,6 +5,16 @@ https://github.com/huggingface/nanotron/blob/2e21db0db46a40bedbd03714616dd0ae4ea `DataCollatorForCLM` is a custom data collator that takes a list of input_ids and returns a dictionary with the input_ids and the labels on the ranks which need it. For example `input_ids` are only needed in the first PP rank, while `labels` are needed in the last PP rank. +And to test it out, you should fix your config to have: (example: [config_custom_dl.yaml](config_custom_dl.yaml)) +```yaml +- data: + dataset: null + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +``` + To try it out you can run the following command: ```bash diff --git a/examples/custom-dataloader/config_custom_dl.yaml b/examples/custom-dataloader/config_custom_dl.yaml index 81941f1f..8bac5db6 100644 --- a/examples/custom-dataloader/config_custom_dl.yaml +++ b/examples/custom-dataloader/config_custom_dl.yaml @@ -6,13 +6,7 @@ checkpoints: save_initial_state: false data_stages: - data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k - hf_dataset_splits: train - text_column_name: text + dataset: null num_loading_workers: 1 seed: 42 name: Stable Training Stage From aac6e7b11de3a103571544622dee75499c5c2b68 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Mon, 6 May 2024 00:06:56 +0000 Subject: [PATCH 6/6] . --- examples/custom-dataloader/README.md | 2 +- examples/custom-dataloader/config_custom_dl.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/custom-dataloader/README.md b/examples/custom-dataloader/README.md index b705794e..9ded4b3a 100644 --- a/examples/custom-dataloader/README.md +++ b/examples/custom-dataloader/README.md @@ -8,7 +8,7 @@ https://github.com/huggingface/nanotron/blob/2e21db0db46a40bedbd03714616dd0ae4ea And to test it out, you should fix your config to have: (example: [config_custom_dl.yaml](config_custom_dl.yaml)) ```yaml - data: - dataset: null + dataset: null # Custom dataloader will be used num_loading_workers: 1 seed: 42 name: Stable Training Stage diff --git a/examples/custom-dataloader/config_custom_dl.yaml b/examples/custom-dataloader/config_custom_dl.yaml index 8bac5db6..970e7407 100644 --- a/examples/custom-dataloader/config_custom_dl.yaml +++ b/examples/custom-dataloader/config_custom_dl.yaml @@ -6,7 +6,7 @@ checkpoints: save_initial_state: false data_stages: - data: - dataset: null + dataset: null # Custom dataloader will be used num_loading_workers: 1 seed: 42 name: Stable Training Stage
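
A closing note on replacing the random data with your own: the
`datasets.Dataset.from_dict(...)` block in run_train.py above is the intended
swap point. A minimal sketch of what that swap could look like, assuming your
token ids live in a hypothetical `my_tokens.npy` file of shape
(num_samples, sequence_length + 1); the extra position is what lets the
collator build shifted labels:

```python
import datasets
import numpy as np

# Hypothetical pre-tokenized corpus: each row is one sample of
# sequence_length + 1 token ids, all smaller than the model's vocab_size.
token_ids = np.load("my_tokens.npy")

train_dataset = datasets.Dataset.from_dict({"input_ids": token_ids})
# From here, reuse the DataCollatorForCLM + DataLoader code from the
# `data.dataset is None` branch of run_train.py unchanged.
```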