Skip to content

Commit

Permalink
remove fp8 tp from llama's modeling code, fix no grad in param, remov…
Browse files Browse the repository at this point in the history
…e rms norm due to illegal memory
  • Loading branch information
xrsrke committed Nov 20, 2024
1 parent 9510f57 commit 478984a
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 59 deletions.
109 changes: 109 additions & 0 deletions examples/config_tiny_fp8_llama.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
checkpoints:
checkpoint_interval: 10
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: debug
run: tiny_llama_%date_%jobid
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 16
initializer_range: 0.02
intermediate_size: 64
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 4
num_hidden_layers: 2
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 256
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_starting_step: null
lr_decay_steps: 13
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
expert_parallel_size: 1
pp: 2
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 256
train_steps: 15
val_check_interval: -1
52 changes: 26 additions & 26 deletions examples/config_tiny_llama.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
checkpoints:
checkpoint_interval: 10
checkpoint_interval: 10000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
Expand All @@ -10,25 +10,25 @@ data_stages:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
# - data:
# dataset:
# dataset_overwrite_cache: false
# dataset_processing_num_proc_per_process: 1
# hf_dataset_config_name: null
# hf_dataset_or_datasets: stas/openwebtext-10k
# hf_dataset_splits: train
# text_column_name: text
# num_loading_workers: 1
# seed: 42
# name: Annealing Phase
# start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
Expand All @@ -44,29 +44,29 @@ logging:
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: float8
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 16
hidden_size: 1024
initializer_range: 0.02
intermediate_size: 64
intermediate_size: 4096
is_llama_config: true
max_position_embeddings: 256
max_position_embeddings: 1024
num_attention_heads: 4
num_hidden_layers: 2
num_hidden_layers: 6
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 256
vocab_size: 1024
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
Expand All @@ -87,13 +87,13 @@ optimizer:
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
dp: 1
expert_parallel_size: 1
pp: 2
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
tp_linear_async_communication: false
tp_mode: ALL_REDUCE
profiler: null
tokenizer:
tokenizer_max_length: null
Expand All @@ -104,6 +104,6 @@ tokens:
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 256
train_steps: 15
sequence_length: 1024
train_steps: 1500
val_check_interval: -1
46 changes: 24 additions & 22 deletions src/nanotron/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,16 @@
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
from nanotron.parallel.tensor_parallel.nn import (
FP8TensorParallelColumnLinear,
FP8TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelLinearMode,
TensorParallelRowLinear,
)
from nanotron.random import RandomStates
from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator
Expand Down Expand Up @@ -222,26 +221,24 @@ def __init__(
config.intermediate_size, # shape of up_linear
)
# self.gate_up_proj = TensorParallelColumnLinear(
self.gate_up_proj = FP8TensorParallelColumnLinear(
self.gate_up_proj = TensorParallelColumnLinear(
config.hidden_size,
2 * config.intermediate_size,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=gate_up_contiguous_chunks,
name=f"model.decoder.{layer_idx}.mlp.gate_up_proj",
# name=f"model.decoder.{layer_idx}.mlp.gate_up_proj",
# tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# self.down_proj = TensorParallelRowLinear(
self.down_proj = FP8TensorParallelRowLinear(
self.down_proj = TensorParallelRowLinear(
config.intermediate_size,
config.hidden_size,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
name=f"model.decoder.{layer_idx}.mlp.down_proj",
)
self.split_silu_mul = GLUActivation(config.hidden_act)

Expand Down Expand Up @@ -386,16 +383,16 @@ def __init__(
config.num_key_value_heads * self.d_qk, # shape of k
config.num_key_value_heads * self.d_qk, # shape of v
)
# self.qkv_proj = TensorParallelColumnLinear(
self.qkv_proj = FP8TensorParallelColumnLinear(
# self.qkv_proj = FP8TensorParallelColumnLinear(
self.qkv_proj = TensorParallelColumnLinear(
self.d_model,
config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
name=f"model.decoder.{layer_idx}.attention.qkv_proj",
# name=f"model.decoder.{layer_idx}.attention.qkv_proj",
# tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
Expand All @@ -418,15 +415,14 @@ def __init__(
dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved
)

# self.o_proj = TensorParallelRowLinear(
self.o_proj = FP8TensorParallelRowLinear(
self.o_proj = TensorParallelRowLinear(
config.num_attention_heads * self.d_qk,
self.d_model,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
name=f"model.decoder.{layer_idx}.attention.o_proj",
# name=f"model.decoder.{layer_idx}.attention.o_proj",
)

self.attention = CoreAttention(
Expand Down Expand Up @@ -710,15 +706,19 @@ def __init__(
layer_idx: int,
):
super().__init__()
self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# NOTE: i got an illegal memory access was encountered when using TritonRMSNorm
# even downgrad flash_attn to 2.4.2
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attn = CausalSelfAttention(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
)

self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx)

self.recompute_layer = parallel_config.recompute_layer
Expand Down Expand Up @@ -856,17 +856,19 @@ def __init__(

self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=TritonRMSNorm,
module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
# module_builder=TritonRMSNorm,
# module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
module_builder=nn.LayerNorm,
module_kwargs={"normalized_shape": config.hidden_size, "eps": config.rms_norm_eps},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
) # TODO

self.lm_head = PipelineBlock(
p2p=self.p2p,
# Understand that this means that we return sharded logits that are going to need to be gathered
# module_builder=TensorParallelColumnLinear,
module_builder=FP8TensorParallelColumnLinear,
# module_builder=FP8TensorParallelColumnLinear,
module_builder=TensorParallelColumnLinear,
module_kwargs={
"in_features": config.hidden_size,
"out_features": config.vocab_size,
Expand Down Expand Up @@ -930,8 +932,8 @@ def get_block_compute_costs(self):
LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 3 * d_ff * model_config.hidden_size,
# This is the last lm_head
# TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
FP8TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
# FP8TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
}
return block_compute_costs

Expand Down
Loading

0 comments on commit 478984a

Please sign in to comment.