Commit

Merge pull request #54 from huggingface/xrsrke/feature_doremi_new_codebase

[Feature] DoReMi
xrsrke authored Feb 22, 2024
2 parents 53c3064 + 0dd67f7 commit 9f9af42
Showing 32 changed files with 2,951 additions and 22 deletions.
15 changes: 13 additions & 2 deletions .github/workflows/3d_parallelism_unit_tests.yaml
@@ -37,7 +37,7 @@ jobs:
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Instal nanotron
- name: Install nanotron's dependencies
run: |
python -m pip install --upgrade pip
pip install packaging
@@ -49,7 +49,7 @@ jobs:
- name: Show installed libraries and their versions
run: pip freeze | tee installed.txt

- name: Run tests
- name: Run nanotron tests
# NOTE: -m "not fa2" will run all the unit tests that don't have the mark
# "fa2" (these are FA2-related tests, we can't run it on T4)
run: |
@@ -61,3 +61,14 @@ jobs:
--ignore tests/fp8 \
--verbose \
tests/
# NOTE: T4 can't run FA2, DoReMi's LLaMa needs FA2
# - name: Run DoReMi tests
# # NOTE: -m "not fa2" will run all the unit tests that don't have the mark
# # "fa2" (these are FA2-related tests, we can't run it on T4)
# run: |
# pip install -r examples/doremi/requirements.txt && \
# pytest \
# --color=yes \
# --durations=0 \
# --verbose \
# examples/doremi/tests/
14 changes: 14 additions & 0 deletions .pre-commit-config.yaml
@@ -19,3 +19,17 @@ repos:
args:
- --fix
- --exit-non-zero-on-fix
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
args:
- --profile=black
- --skip-glob=wandb/**/*
- --thirdparty=wandb
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
- id: codespell
args:
- --ignore-words-list=nd,reacher,thist,ths,magent,ba,fo
88 changes: 88 additions & 0 deletions examples/doremi/README.md
@@ -0,0 +1,88 @@
# DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining
Paper: https://arxiv.org/abs/2305.10429

You might think that the key ways to speed up pretraining are finding more high-quality data, increasing FLOPs, or changing the model architecture, but these are not the only options. DoReMi shows that, given the same source of training data, a model trained with an optimized data mixture can outperform its counterpart trained with random sampling on at least 70% of domains (or on all domains) and on downstream evaluations, without any knowledge of the downstream evaluation tasks.

In our implementation, the experimental results show that DoReMi outperforms the baseline on 15 out of 22 domains on the test set and achieves a lower average cross-entropy test loss. Here are comparisons of the training losses between:

- 280M proxy and reference model [[link]](https://wandb.ai/neuralink/nanotron/reports/-DoReMi-280m-reference-vs-280m-proxy-s-training--Vmlldzo2NzYwNTU1)
- 2.5B reference and tuned weight models [[link]](https://wandb.ai/neuralink/nanotron/reports/-DoReMi-2-5B-tuned-weights-vs-2-5B-token-ratio-domain-weights-s-training--Vmlldzo2NzYwNzE2)
- And how the 280M proxy model's domain weights change during training [[link]](https://wandb.ai/neuralink/nanotron/runs/j9ojbso1?workspace=user-neuralink)


![The domains in which we outperform](./assets/outperform.png)


![The domains in which we don't outperform](./assets/not_outperform.png)


![Domain weights comparison](./assets/domain_weights.png)

**Notes**: The graphs above show test losses, not validation losses (the label is a typo 🫠). The x-axis carries no meaning; each point simply corresponds to another batch of evaluation data sampled from the same final checkpoint.

### How it works

- Step 0: Preprocessing data

- Step 1: Train a small reference model using uniform sampling from each domain: for a given global batch size, you draw an equal number of samples from every domain. In some cases a domain has fewer samples than the others and runs out early, so you can instead enable automatic domain weights based on each domain's token count (illustrated in the sketch after the command below).

```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 examples/doremi/train_reference.py --config-file examples/doremi/configs/config_280m_llama.yaml
```
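
As a concrete illustration of the token-count-based weights mentioned above, here is a minimal sketch that computes token-ratio domain weights from tokenized domain folders. It assumes the Hugging Face `datasets` library, an `input_ids` column, and illustrative folder names; it is not nanotron's built-in mechanism.

```python
import datasets  # Hugging Face datasets; assumed layout: one tokenized folder per domain

domain_names = ["domain_0", "domain_1", "domain_2"]  # illustrative names

# Count tokens per domain, then normalize to get token-ratio domain weights.
token_counts = []
for name in domain_names:
    ds = datasets.load_from_disk(f"dataset/{name}")
    token_counts.append(sum(len(sample["input_ids"]) for sample in ds))

total = sum(token_counts)
domain_weights = [count / total for count in token_counts]
print(dict(zip(domain_names, domain_weights)))
```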

- Step 2: Use the trained reference model from step 1 to train an identical proxy model, and use the reference model's performance to dynamically tune the domain weights during training.

```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 examples/doremi/train_doremi.py --config-file examples/doremi/configs/config_280m_llama_proxy.yaml
```

- Step 3: Nanotron saves the domain weights in the model checkpoint. Now, calculate the optimal domain weights by averaging the domain weights across all training steps from step 2: $\bar{\alpha}=\frac{1}{T} \sum_{t=1}^T \alpha_t$.


```python

import torch

# The checkpoint holds one entry per training step, each with the domain weights at that step.
domain_weights = torch.load("checkpoints/doremi/proxy-280m-llama/doremi_domain_weights_100000.pt")

# Average the per-step domain weights to obtain the optimized mixture.
total_weights = sum(d["domain_weights"] for d in domain_weights)
avg_weights = total_weights / len(domain_weights)
```

Then, set these `avg_weights` in the config of the larger run in the `doremi` section.
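
For illustration, the `doremi` section of the larger run's config might look like the sketch below; the domain names are truncated and the weight values are placeholders, not the actual tuned weights (the configs further down use the same comma-separated format).

```yaml
doremi:
  domain_names: Pile-CC, Github, OpenWebText2
  # NOTE: placeholder values; paste the averaged `avg_weights` from step 3 here,
  # in the same order as domain_names
  domain_weights: 0.5, 0.3, 0.2
```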

- Step 4: Use the optimized domain weights from step 3 to train a larger model (could be 10x to 30x larger).

```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 examples/doremi/train_reference.py --config-file examples/doremi/configs/config_2.8b_llama_with_tuned_weights.yaml
```

### Dataset

We expect the dataset path to point to a folder that already contains tokenized data with the following structure:

```
dataset/
    domain_0/
        ...
    domain_1/
        ...
    domain_2/
        ...
```

For each tokenized sample, we expect a column named `domain_ids` that contains the index of the sample's domain in the dataset. For example, if a sample comes from the third domain, its `domain_ids` should equal 2. The folder names must match the domain names that you provide in the DoReMi config; a hypothetical preparation sketch follows below.
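
Below is a hypothetical preparation sketch for this layout. It assumes the Hugging Face `datasets` and `transformers` libraries and raw JSONL files with a `text` column; the paths and domain names are placeholders, and only the `domain_ids` column follows the convention described above.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
domain_names = ["arxiv", "github", "wikipedia"]  # must match the DoReMi config

for domain_idx, name in enumerate(domain_names):
    # Placeholder raw data: one JSONL file per domain with a "text" column.
    raw = load_dataset("json", data_files=f"raw/{name}.jsonl", split="train")
    tokenized = raw.map(
        lambda batch: tokenizer(batch["text"]),
        batched=True,
        remove_columns=raw.column_names,
    )
    # Tag every sample with the index of its domain (third domain -> 2, etc.).
    tokenized = tokenized.add_column("domain_ids", [domain_idx] * len(tokenized))
    tokenized.save_to_disk(f"dataset/{name}")
```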

### The Experiment

We first train a small 280M model for 70k steps on The Pile to obtain a reference model. Then, we use the reference model to tune the domain weights of an identically sized model trained from scratch (aka proxy training) for 70k steps.

The reference model's performance is used as a baseline to determine how difficult a domain is, so that the DoReMi algorithm can adjust the domain weights accordingly on the fly (a sketch of this update follows below). Once we obtain the optimized weights, we use them to train a 2.5B model (9x larger than the reference model) for 70k steps, and we train another 2.5B model using token-ratio domain weights (technically the same as random sampling, since the probability of a token occurring in the training data equals its token ratio).
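
To make the reweighting concrete, here is a minimal sketch of a DoReMi-style domain-weight update as described in the paper: compute each domain's excess loss (proxy loss minus reference loss, clipped at zero), apply an exponentiated update, and smooth with the uniform distribution. The function name, step size, and smoothing values are illustrative, not nanotron's actual implementation.

```python
import torch

def update_domain_weights(domain_weights, proxy_losses, ref_losses,
                          step_size=1.0, smoothing=1e-3):
    """One DoReMi-style update: upweight domains where the proxy lags the reference most."""
    # Excess loss: how much worse the proxy model is than the reference, per domain.
    excess = torch.clamp(proxy_losses - ref_losses, min=0.0)
    # Exponentiated update in log space, then renormalize.
    new_weights = torch.softmax(torch.log(domain_weights) + step_size * excess, dim=0)
    # Mix with the uniform distribution so no domain's weight collapses to zero.
    uniform = torch.full_like(new_weights, 1.0 / len(new_weights))
    return (1 - smoothing) * new_weights + smoothing * uniform

# Example with 3 domains (loss values are made up):
alpha = torch.tensor([1 / 3, 1 / 3, 1 / 3])
proxy_losses = torch.tensor([2.9, 3.4, 3.1])
ref_losses = torch.tensor([3.0, 3.1, 3.1])
alpha = update_domain_weights(alpha, proxy_losses, ref_losses)
```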

For evaluation, we uniformly sample from the test set to evaluate the 2.5B models trained with the optimized domain weights and with the token-ratio domain weights. For more details on the hyperparameters, please check the config YAMLs. Here are the model checkpoints from the experiment:
- 280M LLaMA reference model: https://huggingface.co/nanotron/doremi-llama-280m-reference
- 280M LLaMA proxy model: https://huggingface.co/nanotron/doremi-llama-280m-proxy
- 2.5B LLaMA reference model: https://huggingface.co/nanotron/doremi-llama-2.5b-reference
- 2.5B LLaMA trained with the optimized weights: https://huggingface.co/nanotron/doremi-llama-2.5b-optimized-weights

and the dataset: https://huggingface.co/datasets/nanotron/the-pile-for-doremi
Empty file added examples/doremi/__init__.py
Binary file added examples/doremi/assets/domain_weights.png
Binary file added examples/doremi/assets/not_outperform.png
Binary file added examples/doremi/assets/outperform.png
121 changes: 121 additions & 0 deletions examples/doremi/configs/config_2.8b_llama.yaml
@@ -0,0 +1,121 @@
checkpoints:
checkpoint_interval: 1000
checkpoints_path: checkpoints/doremi/big-run-02/reference-2.8b-llama
checkpoints_path_is_shared_file_system: true
resume_checkpoint_path: checkpoints/doremi/big-run-02/reference-2.8b-llama/70000
save_initial_state: false

doremi:
domain_names: Pile-CC, Github, OpenWebText2, StackExchange, Wikipedia (en), PubMed Abstracts, USPTO Backgrounds, FreeLaw, PubMed Central, Enron Emails, HackerNews, NIH ExPorter, Books3, ArXiv, DM Mathematics, OpenSubtitles, Gutenberg (PG-19), Ubuntu IRC, BookCorpus2, EuroParl, YoutubeSubtitles, PhilPapers
# domain_weights: 0.1500, 0.1213, 0.0872, 0.0631, 0.0340, 0.0240, 0.0281, 0.0594, 0.1599, 0.0015, 0.0058, 0.0021, 0.0605, 0.1136, 0.0209, 0.0154, 0.0202, 0.0037, 0.0065, 0.0100, 0.0093, 0.0036

data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null

hf_dataset_or_datasets: project_data/doremi/datasets/the_pile_raw/tokenized_data/train

num_loading_workers: 1
seed: 42
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: nanotron
run: train_2.8b_llama_reference
seed: 42
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 120
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
# NOTE: only change hidden_size, intermediate_size,
# num_attention_heads, num_key_value_heads and num_hidden_layers
hidden_size: 4096
initializer_range: 0.02
intermediate_size: 24576
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 32
num_hidden_layers: 6
# num_hidden_layers: 1
num_key_value_heads: 16
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_steps: 8
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
# dp: 8
# # dp: 2
# pp: 1
# tp: 8
# # tp: 2

# NOTE: for running eval
dp: 8
pp: 1
tp: 2

pp_engine: 1f1b
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
tokenizer_revision: null
tokens:
# batch_accumulation_per_replica * micro_batch_size * dp = 4 * 8 * 16 = 512
# batch_accumulation_per_replica * micro_batch_size * dp = 16 * 8 * 4 = 512
# 240 * 1024 = 245760
# the doremi paper does 500k tokens per batch
# batch_accumulation_per_replica: 16
# NOTE: some weird bug where, if you run batch_accumulation_per_replica=16,
# it results in no samples from some domains

# NOTE: this causes some domain losses to be 0
# batch_accumulation_per_replica: 8
# micro_batch_size: 8

batch_accumulation_per_replica: 1
micro_batch_size: 64

limit_test_batches: 0
# NOTE: this is like the number of microbatches for validation
limit_val_batches: 1
sequence_length: 1024
# train_steps: 1000
# train_steps: 1579
train_steps: 70_000
val_check_interval: 2
119 changes: 119 additions & 0 deletions examples/doremi/configs/config_2.8b_llama_with_tuned_weights.yaml
@@ -0,0 +1,119 @@
checkpoints:
checkpoint_interval: 5000
checkpoints_path: checkpoints/doremi/big-run-02/reference-2.8b-llama-tuned-weights_with_100k_proxy
checkpoints_path_is_shared_file_system: true
resume_checkpoint_path: checkpoints/doremi/big-run-02/reference-2.8b-llama-tuned-weights_with_100k_proxy/70000
save_initial_state: false

doremi:
domain_names: Pile-CC, Github, OpenWebText2, StackExchange, Wikipedia (en), PubMed Abstracts, USPTO Backgrounds, FreeLaw, PubMed Central, Enron Emails, HackerNews, NIH ExPorter, Books3, ArXiv, DM Mathematics, OpenSubtitles, Gutenberg (PG-19), Ubuntu IRC, BookCorpus2, EuroParl, YoutubeSubtitles, PhilPapers
# domain_weights: 0.2333, 0.0700, 0.1154, 0.0528, 0.0665, 0.0670, 0.0366, 0.0571, 0.0451, 0.0036, 0.0087, 0.0078, 0.0708, 0.0656, 0.0034, 0.0048, 0.0222, 0.0084, 0.0038, 0.0186, 0.0149, 0.0235

data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null

hf_dataset_or_datasets: project_data/doremi/datasets/the_pile_raw/tokenized_data/train

num_loading_workers: 1
seed: 42
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: nanotron
run: train_tuned_2.8b_model
seed: 42
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 120
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 4096
initializer_range: 0.02
intermediate_size: 24576
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 32
# num_hidden_layers: 40
num_hidden_layers: 6
num_key_value_heads: 16
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_steps: 8
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
# dp: 8
# pp: 1
# tp: 8
# tp: 2

# NOTE: for running eval
dp: 1
pp: 1
tp: 8
pp_engine: 1f1b
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
tokenizer_revision: null
tokens:
# batch_accumulation_per_replica * micro_batch_size * dp = 4 * 8 * 16 = 512
# batch_accumulation_per_replica * micro_batch_size * dp = 16 * 8 * 4 = 512
# batch_accumulation_per_replica * micro_batch_size * dp = 8 * 8 * 8 = 512 (this one)
# 240 * 1024 = 245760
# the doremi paper does 500k tokens per batch
# batch_accumulation_per_replica: 16

# NOTE: some weird bug where, if you run batch_accumulation_per_replica=16,
# it results in no samples from some domains

# NOTE: this causes some domain losses to be 0
# batch_accumulation_per_replica: 8
# micro_batch_size: 8

batch_accumulation_per_replica: 1
micro_batch_size: 64

limit_test_batches: 0
limit_val_batches: 1
sequence_length: 1024
# train_steps: 1000
# train_steps: 70_000
# train_steps: 70_000
train_steps: 70_010
val_check_interval: -1