Ring attention #181

Open · wants to merge 21 commits into main
22 changes: 14 additions & 8 deletions examples/llama/tests/test_conversion.py
@@ -105,7 +105,7 @@ def _test_nt_to_hf(parallel_context: ParallelContext, input_ids: torch.Tensor):


def test_nt_to_hf(input_ids: torch.Tensor):
init_distributed(tp=1, dp=1, pp=1)(_test_nt_to_hf)(input_ids=input_ids)
init_distributed(tp=1, dp=1, pp=1, sp=1)(_test_nt_to_hf)(input_ids=input_ids)


def _test_nt_to_hf_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext):
@@ -130,7 +130,9 @@ def _test_nt_to_hf_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext):


def test_nt_to_hf_with_files(input_ids: torch.Tensor):
init_distributed(tp=1, dp=1, pp=1)(_test_nt_to_hf_with_files)(input_ids=input_ids, test_context=TestContext())
init_distributed(tp=1, dp=1, pp=1, sp=1)(_test_nt_to_hf_with_files)(
input_ids=input_ids, test_context=TestContext()
)


def _test_hf_to_nt(parallel_context: ParallelContext, input_ids: torch.Tensor):
@@ -141,11 +143,11 @@ def _test_hf_to_nt(parallel_context: ParallelContext, input_ids: torch.Tensor):
logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2)
logits_hf = model_hf(input_ids).logits
assert logits_nt.size() == logits_hf.size()
torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL)
torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL)


def test_hf_to_nt(input_ids: torch.Tensor):
init_distributed(tp=1, dp=1, pp=1)(_test_hf_to_nt)(input_ids=input_ids)
init_distributed(tp=1, dp=1, pp=1, sp=1)(_test_hf_to_nt)(input_ids=input_ids)


def _test_hf_to_nt_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext):
@@ -168,7 +170,9 @@ def _test_hf_to_nt_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext):


def test_hf_to_nt_with_files(input_ids: torch.Tensor):
init_distributed(tp=1, dp=1, pp=1)(_test_hf_to_nt_with_files)(input_ids=input_ids, test_context=TestContext())
init_distributed(tp=1, dp=1, pp=1, sp=1)(_test_hf_to_nt_with_files)(
input_ids=input_ids, test_context=TestContext()
)


def _test_composed_conversion(parallel_context: ParallelContext):
@@ -196,7 +200,7 @@ def _test_composed_conversion(parallel_context: ParallelContext):


def test_composed_conversion():
init_distributed(tp=1, dp=1, pp=1)(_test_composed_conversion)()
init_distributed(tp=1, dp=1, pp=1, sp=1)(_test_composed_conversion)()


def _save_parallel_nanotron(parallel_context: ParallelContext, input_ids: torch.Tensor, nt_path: Path):
@@ -239,9 +243,11 @@ def test_tensor_parallel_conversion(input_ids: torch.Tensor):
hf_path = root / "nanotron"

# Launch both parts.
init_distributed(tp=2, dp=1, pp=1)(_save_parallel_nanotron)(input_ids=input_ids, nt_path=nt_path)
init_distributed(tp=2, dp=1, pp=1, sp=1)(_save_parallel_nanotron)(input_ids=input_ids, nt_path=nt_path)
assert (nt_path / "logits.pt").exists()
init_distributed(tp=1, dp=1, pp=1)(_convert_from_parallel)(input_ids=input_ids, nt_path=nt_path, hf_path=hf_path)
init_distributed(tp=1, dp=1, pp=1, sp=1)(_convert_from_parallel)(
input_ids=input_ids, nt_path=nt_path, hf_path=hf_path
)
assert (hf_path / "logits.pt").exists()

# Load logits and verify they match.
34 changes: 34 additions & 0 deletions examples/long-context/README.md
@@ -0,0 +1,34 @@
# Long Context with Sequence Parallelism

## Sequence Parallelism
Sequence parallelism splits the activations along the sequence dimension across devices, reducing the memory footprint. If your input is long enough (e.g., sequence length > 16K), you may need sequence parallelism.

Nanotron implements sequence parallelism via ring attention. To use this feature, you only need to set *sp* > 1. However, since the TP reduce-scatter mode splits not only the model's weights/gradients/optimizer states but also the activations, I suggest increasing TP to 8 first and then increasing SP only if the activations still need to be split further.
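
As a rough sketch (assuming the world size is the product of all parallel degrees, with `sp` being the ring-attention degree added in this PR), here is how the degrees from the first training stage below map onto GPUs and sequence shards:

```python
# Rough sketch, not part of the PR: how dp/pp/tp/sp map onto GPUs and sequence shards.
# Assumes world_size = dp * pp * tp * sp and that ring attention shards the sequence
# across the sp ranks; values are taken from the first training stage below.
dp, pp, tp, sp = 1, 1, 8, 4
sequence_length = 65536

world_size = dp * pp * tp * sp
tokens_per_sp_rank = sequence_length // sp

print(f"GPUs required: {world_size}")                              # 32 = 4 nodes x 8 GPUs (see launch_training.slurm)
print(f"sequence shard per SP rank: {tokens_per_sp_rank} tokens")  # 16384
```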

## Experiment Setup
Our experiments show that when training with long sequences (e.g., scaling from 8K to 128K), gradually increasing the sequence length and the RoPE theta speeds up convergence and saves computational resources.

For example, when extending the context length from 8K to 1M, this is how I increased the sequence length and the RoPE base (see the sketch after the table for how the base relates to the rotation frequencies):

| Sequence Length | RoPE Theta | Train Steps | Batch Accumulation per Replica | Micro Batch Size |
|-----------------|-----------------|-------------|-------------------------------|------------------|
| 65536 | 22,400,000.0 | 10 | 32 | 1 |
| 131072 | 80,000,000.0 | 10 | 32 | 1 |
| 524288 | 1,000,000,000.0 | 10 | 16 | 1 |
| 1048576 | 3,600,000,000.0 | 10 | 8 | 1 |
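
As a rough illustration of why the base grows with the sequence length (this is not part of the training code), the standard RoPE inverse frequencies for these bases can be computed as follows; a head dimension of 128 is assumed here:

```python
# Rough illustration, not part of the training code: standard RoPE inverse frequencies
# inv_freq[i] = theta ** (-2i / head_dim). A larger theta lowers the slowest frequency,
# i.e. stretches the longest positional wavelength, which is why the base is raised
# together with the context length. head_dim=128 is assumed here (Llama-3-8B attention heads).
import math

head_dim = 128
for theta in (22_400_000.0, 80_000_000.0, 1_000_000_000.0, 3_600_000_000.0):
    slowest_freq = theta ** (-(head_dim - 2) / head_dim)   # smallest inv_freq component
    longest_wavelength = 2 * math.pi / slowest_freq        # tokens per full rotation
    print(f"theta={theta:.2e} -> longest positional wavelength ~ {longest_wavelength:.2e} tokens")
```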

For this reason, I added a template file in this folder. You can customize the hyper-parameters for each training stage and then generate the config files by running: `python create_config.py`

Then for training, you only need to execute: `sbatch launch_training.slurm`

## Experiment results
We successfully extended the model's context length to 1M tokens using less than 0.3B tokens and 50 training steps, achieving 100% accuracy in the needle-in-a-haystack experiment.
![1M Context Length Example](./images/1M.png)

To demonstrate how quickly the model adapts to long contexts, we created a GIF showing the model completing a needle-in-a-haystack test over a 64K context length within 10 training steps.

![64K evolution gif](./images/theta=22p4M.gif)

The choice of the RoPE base value is crucial. As an example of an incorrect setting, using a base of 1M and then training for 15 steps yields a model that can hardly complete the needle-in-a-haystack task.

![ablation study](./images/theta=1M_15steps.jpeg)
127 changes: 127 additions & 0 deletions examples/long-context/create_config.py
@@ -0,0 +1,127 @@
import os

from jinja2 import Environment, FileSystemLoader

# Choose the template
template_file = "llama3_template.jinja2"

# Load the template environment
current_directory = os.getcwd()
template_directory = os.path.join(current_directory, "templates")

env = Environment(loader=FileSystemLoader(template_directory))
template = env.get_template(template_file)
print("Template file: {}".format(template_file))
print()


###### First create a template file ######
# You should define the model/tokenizer/dataset in the template file
# You can use llama3 converter to get the weights and tokenizer
tokenizer_path = "/fsx/haojun/lighteval_evaluation_model/NanotronLlama3-8B-Instruct" # replace with your own
init_weights_path = "/fsx/haojun/lighteval_evaluation_model/NanotronLlama3-8B-Instruct" # replace with your own
dataset_folder = "/fsx/haojun/datasets/tokenized_bytes_4B_tokens" # replace with your own
###### end ######


############ hyper-parameters for experiments ############
experiment_name = "1M_4stages"
sequence_lengths = [65536, 131072, 524288, 1048576]  # model sequence length per stage
rope_thetas = [22400000.0, 80000000.0, 1000000000.0, 3600000000.0]  # RoPE base value per stage
train_steps = [10, 20, 30, 40]  # cumulative train steps (each stage resumes from the previous one)
batch_accumulation_per_replicas = [32, 32, 16, 8]  # gradient accumulation steps
micro_batch_sizes = [1, 1, 1, 1]  # micro batch size
sps = [4, 8, 32, 64]  # sequence parallelism degree
tp = 8  # tensor parallelism degree
checkpoint_intervals = [1, 1, 1, 1]  # save a checkpoint every N steps

############ end ############

############ optimizer ############
lr_warmup_steps = 1
lr_decay_steps = 1
learning_rate = 0.00002
min_decay_lr = 0.00002
############ end ############


############ checkpoints/config path ############
# model weights output directory
checkpoints_path = os.path.join(current_directory, "weights", experiment_name)
checkpoints_paths = [checkpoints_path] * len(sequence_lengths)
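# The first stage does not resume from a checkpoint (it is initialized from init_weights_path
# via the template); each later stage resumes from the checkpoint saved by the previous stage
# under the same checkpoints_path.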
resume_checkpoint_paths = ["null"] + [checkpoints_path] * (len(sequence_lengths) - 1)

# Config files output directory
output_dir = os.path.join(current_directory, "configs", experiment_name)
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
print(f"Created config directory: {output_dir}")
############ end ############

# Ensure that all lists have the same number of elements (one entry per training stage)
list_lengths = [
len(checkpoints_paths),
len(resume_checkpoint_paths),
len(sequence_lengths),
len(rope_thetas),
len(train_steps),
len(batch_accumulation_per_replicas),
len(micro_batch_sizes),
len(sps),
len(checkpoint_intervals),
]
if not all(length == list_lengths[0] for length in list_lengths):
raise ValueError("All input lists must have the same length.")


def format_float(value, decimal_places=5):
return f"{value:.{decimal_places}f}"


for i in range(len(checkpoints_paths)):
checkpoints_path = checkpoints_paths[i]
resume_checkpoint_path = resume_checkpoint_paths[i]
checkpoint_interval = checkpoint_intervals[i]
batch_accumulation_per_replica = batch_accumulation_per_replicas[i]
sequence_length = sequence_lengths[i]
rope_theta = rope_thetas[i]
train_step = train_steps[i]
micro_batch_size = micro_batch_sizes[i]
sp = sps[i]

# Create the checkpoint path if it doesn't exist
if not os.path.exists(checkpoints_path):
os.makedirs(checkpoints_path, exist_ok=True)
print(f"Created directory: {checkpoints_path}")

variables = {
"tokenizer_path": tokenizer_path,
"init_weights_path": init_weights_path,
"dataset_folder": dataset_folder,
"checkpoints_path": checkpoints_path,
"resume_checkpoint_path": resume_checkpoint_path,
"checkpoint_interval": checkpoint_interval,
"sequence_length": sequence_length,
"rope_theta": rope_theta,
"train_steps": train_step,
"batch_accumulation_per_replica": batch_accumulation_per_replica,
"micro_batch_size": micro_batch_size,
"learning_rate": format_float(learning_rate),
"min_decay_lr": format_float(min_decay_lr),
"lr_warmup_steps": lr_warmup_steps,
"lr_decay_steps": lr_decay_steps,
"sp": sp,
"tp": tp,
}

# Render the template with the provided variables
config = template.render(variables)

    # Define the output file name
output_file = f"{output_dir}/config_{i}_theta={rope_theta/1e6}M_steps={train_step}_seq_len={sequence_length}.yaml"

# Save the rendered configuration to a YAML file
with open(output_file, "w") as f:
f.write(config)

print(f"Configuration file '{output_file}' has been generated.")
Binary file added examples/long-context/images/1M.png
Binary file added examples/long-context/images/theta=1M_15steps.jpeg
Binary file added examples/long-context/images/theta=22p4M.gif
53 changes: 53 additions & 0 deletions examples/long-context/launch_training.slurm
@@ -0,0 +1,53 @@
#!/bin/bash
#SBATCH --job-name=long-context
#SBATCH --qos=high
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --partition=hopper-prod
#SBATCH --gres=gpu:8
#SBATCH --mem=0

# HF stuff
if [ -f .env ]; then
export $(cat .env | xargs)
fi

# script for a benchmark
set -x -e

# replace with your own env
echo "START TIME: $(date)"
source /admin/home/haojun_zhao/miniconda3/etc/profile.d/conda.sh
conda activate /admin/home/haojun_zhao/miniconda3/envs/nt_mamba
echo python3 version = $(python3 --version)

# SLURM stuff
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=9001
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
export OMP_NUM_THREADS=8
export CUDA_DEVICE_MAX_CONNECTIONS=1
export FI_PROVIDER="efa"
NUM_GPUS=$(nvidia-smi --list-gpus | wc -l)

echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")"
echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")"

module load cuda/12.1

current_dir=$(pwd)
parent_of_parent_dir=$(dirname $(dirname "$current_dir"))

TRAIN_SCRIPT="$parent_of_parent_dir/run_train.py"
CONFIG_FILE="$current_dir/configs/1M_4stages/config_0_theta=22.4M_steps=10_seq_len=65536.yaml"
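
# The srun/torchrun launch below starts one agent per node: 4 nodes x 8 GPUs = 32 ranks,
# which matches tp=8 x sp=4 (dp=1, pp=1) for this first-stage config; ranks rendezvous
# via c10d at $MASTER_ADDR:$MASTER_PORT.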

srun torchrun \
--nnodes=4 \
--nproc_per_node=8 \
--rdzv_backend=c10d \
--rdzv-id ${SLURM_JOB_ID} \
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
--max_restarts=0 \
--tee=3 \
$TRAIN_SCRIPT --config-file $CONFIG_FILE