
Commit

add full functionality
xrsrke committed Jul 12, 2024
1 parent 6400e86 commit 01ea52a
Showing 29 changed files with 641 additions and 230 deletions.
25 changes: 18 additions & 7 deletions examples/fp8/ablations/configs/sanity_bf16.yaml
@@ -1,5 +1,5 @@
checkpoints:
-checkpoint_interval: 50000
+checkpoint_interval: 100
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
# resume_checkpoint_path: checkpoints
@@ -20,9 +20,9 @@ data_stages:
general:
benchmark_csv_path: null
consumed_train_samples: null
-ignore_sanity_checks: true
+ignore_sanity_checks: false
project: fp8_for_nanotron
-run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_fp8_branch_and_layernorm_and_custom_adam_and_tp_1_and_no_weight_decay
+run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_0.01_and_minipile_overfitting_and_fp8_branch_and_layernorm_and_custom_adam_and_tp_1_and_no_weight_decay_after_fixing_init_without_zero_grad
seed: 42
step: null
lighteval: null
@@ -63,21 +63,32 @@ model:
optimizer:
accumulate_grad_in_fp32: false
learning_rate_scheduler:
-learning_rate: 0.0006
+learning_rate: 0.0004
lr_decay_starting_step: null
-lr_decay_steps: null
+lr_decay_steps: 800
lr_decay_style: cosine
-lr_warmup_steps: 0 # 10% warm up of total training steps
+lr_warmup_steps: 200 # 10% warm up of total training steps
lr_warmup_style: linear
min_decay_lr: 0.00006

+# learning_rate_scheduler:
+#   learning_rate: 0.01
+#   lr_decay_starting_step: null
+#   lr_decay_steps: null
+#   lr_decay_style: cosine
+#   lr_warmup_steps: 0 # 10% warm up of total training steps
+#   lr_warmup_style: linear
+#   min_decay_lr: 0.00006

optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: custom_adam
torch_adam_is_fused: true
-weight_decay: 0.
+weight_decay: 0.1
zero_stage: 0
+clip_grad: 1.0

parallelism:
dp: 1
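As a reading aid for the scheduler change above: a minimal sketch of the shape those values imply, assuming the `linear` warmup style ramps from 0 to `learning_rate` over `lr_warmup_steps` and the `cosine` style then decays to `min_decay_lr` over `lr_decay_steps` (nanotron's actual scheduler may differ in details such as the decay starting step and post-decay behaviour).

```python
# Sketch of the warmup + cosine schedule implied by the updated sanity_bf16.yaml values.
import math

PEAK_LR = 4e-4        # learning_rate
WARMUP_STEPS = 200    # lr_warmup_steps
DECAY_STEPS = 800     # lr_decay_steps
MIN_LR = 6e-5         # min_decay_lr

def lr_at(step: int) -> float:
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS                   # linear warmup
    progress = min((step - WARMUP_STEPS) / DECAY_STEPS, 1.0)   # cosine decay, clamped after 800 steps
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1 + math.cos(math.pi * progress))

for s in (0, 100, 200, 600, 1000):
    print(s, f"{lr_at(s):.2e}")
```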
@@ -1,5 +1,5 @@
checkpoints:
-checkpoint_interval: 50000
+checkpoint_interval: 100
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
# resume_checkpoint_path: checkpoints
@@ -22,7 +22,7 @@ general:
consumed_train_samples: null
ignore_sanity_checks: true
project: fp8_for_nanotron
-run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_main_branch_and_layernorm_and_no_weight_decay_and_no_warmup
+run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_main_branch_and_layernorm_and_no_weight_decay_and_no_warmup_and_without_zerograd
seed: 42
step: null
lighteval: null
@@ -92,13 +92,12 @@ model:

optimizer:
accumulate_grad_in_fp32: false
-# clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0006
lr_decay_starting_step: null
-lr_decay_steps: null
+lr_decay_steps: 800
lr_decay_style: cosine
-lr_warmup_steps: 0 # 10% warm up of total training steps
+lr_warmup_steps: 200 # 10% warm up of total training steps
lr_warmup_style: linear
min_decay_lr: 0.00006
optimizer_factory:
15 changes: 15 additions & 0 deletions src/nanotron/constants.py
@@ -23,3 +23,18 @@
PARAM_ID_TO_PARAM_NAMES = None

ITERATION_STEP: int = 1

+DEBUG_PATH = "/fsx/phuc/temp/temp3_env_for_fp8/nanotron/debug/runs"
+DEBUG_SAVE_PATH = "/fsx/phuc/temp/temp3_env_for_fp8/nanotron/debug/runs/{}/{}"
+
+
+def get_debug_save_path(name, iteration_step):
+    import os
+
+    path = DEBUG_SAVE_PATH.format(name, iteration_step)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    return path
+
+
+is_ready_to_log = False
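A hypothetical usage of the `get_debug_save_path` helper added above, just to show the call pattern; the run name `"sanity_run"` and step `5` are invented for illustration, and the directory is created under the author's hard-coded `DEBUG_SAVE_PATH`.

```python
import torch

from nanotron.constants import get_debug_save_path

# NOTE: the target path lives on the author's filesystem (/fsx/...), so this only
# works where that location is writable.
path = get_debug_save_path("sanity_run", 5)   # e.g. .../debug/runs/sanity_run/5, created if missing
torch.save(torch.randn(4, 4), f"{path}/example_tensor.pt")
```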
15 changes: 14 additions & 1 deletion src/nanotron/fp8/distributed.py
@@ -4,9 +4,22 @@
import torch.distributed as dist
from torch.distributed import * # noqa

from nanotron.distributed import *
from nanotron.fp8.parameter import FP8Parameter
from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8
-from nanotron.parallel.parameters import NanotronParameter
+from nanotron.parallel.parameters import NanotronParameter, get_data_from_param


+def all_reduce(
+    tensor: Union[torch.Tensor, NanotronParameter],
+    op: dist.ReduceOp = dist.ReduceOp.SUM,
+    group: Optional[dist.ProcessGroup] = None,
+    async_op: bool = False,
+):
+    assert tensor.__class__ in [torch.Tensor, NanotronParameter]
+    data = get_data_from_param(tensor) if tensor.__class__ == NanotronParameter else tensor
+
+    dist.all_reduce(data, op=op, group=group, async_op=async_op)


def all_gather(
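The new `all_reduce` wrapper unwraps a `NanotronParameter` to its underlying tensor and then delegates to `torch.distributed.all_reduce`. A small sketch of exercising it, assuming a single-process gloo group (the real code path runs inside nanotron's parallel context):

```python
import os

import torch
import torch.distributed as dist

from nanotron.fp8.distributed import all_reduce  # wrapper added in this commit

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

t = torch.ones(4)
all_reduce(t, op=dist.ReduceOp.SUM)  # plain tensors pass through; a NanotronParameter would be unwrapped first
print(t)                             # unchanged with world_size=1

dist.destroy_process_group()
```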
10 changes: 9 additions & 1 deletion src/nanotron/fp8/linear.py
@@ -11,6 +11,7 @@
from nanotron.fp8.parameter import FP8Parameter
from nanotron.fp8.recipe import FP8LinearRecipe
from nanotron.fp8.tensor import FP8Tensor
+from nanotron.parallel.parameters import get_data_from_param


@dataclass
@@ -68,8 +69,15 @@ def __init__(
def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:
import nanotron.fp8.functional as F

+# return F.linear(
+#     input=input, weight=self.weight.data, bias=self.bias, metadatas=self.metadatas, recipe=self.recipe
+# )
return F.linear(
-    input=input, weight=self.weight.data, bias=self.bias, metadatas=self.metadatas, recipe=self.recipe
+    input=input,
+    weight=get_data_from_param(self.weight),
+    bias=get_data_from_param(self.bias),
+    metadatas=self.metadatas,
+    recipe=self.recipe,
)

def __repr__(self) -> str:
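The `forward` change above reads the weight and bias through `get_data_from_param` rather than touching `.data` directly. A toy illustration of that unwrap-before-compute pattern; `WrappedParam` and `get_data` below are invented for this sketch and are not nanotron's implementation:

```python
import torch
import torch.nn.functional as F

class WrappedParam:
    """Stand-in for a parameter wrapper (hypothetical, loosely analogous to NanotronParameter)."""
    def __init__(self, data: torch.Tensor):
        self.data = data

def get_data(p):
    # Analogous role to get_data_from_param: return the underlying plain tensor.
    return p.data if isinstance(p, WrappedParam) else p

class ToyLinear:
    def __init__(self, weight: torch.Tensor, bias: torch.Tensor):
        self.weight, self.bias = WrappedParam(weight), WrappedParam(bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unwrap before calling the functional op, mirroring FP8Linear.forward above.
        return F.linear(x, get_data(self.weight), get_data(self.bias))

layer = ToyLinear(torch.randn(8, 4), torch.zeros(8))
print(layer.forward(torch.randn(2, 4)).shape)  # torch.Size([2, 8])
```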
82 changes: 58 additions & 24 deletions src/nanotron/fp8/optim.py
@@ -17,7 +17,15 @@
convert_tensor_from_fp16,
)
from nanotron.fp8.utils import compute_stas, is_overflow_underflow_nan
-from nanotron.parallel.parameters import NanotronParameter
+from nanotron.parallel.parameters import (
+    NanotronParameter,
+    get_data_from_param,
+    get_data_from_sliced_or_param,
+    get_grad_from_parameter,
+    get_grad_from_sliced_or_param,
+    set_data_for_sliced_or_param,
+    set_grad_none_for_sliced_or_param,
+)


class Adam(Optimizer):
@@ -74,7 +82,8 @@ def step(self, closure=None):
for group in self.param_groups:
for p in group["params"]:
state = self.state[p]
-data = p.data
+# data = p.data
+data = get_data_from_param(p)
assert isinstance(data, torch.Tensor)

if len(state) == 0:
@@ -84,8 +93,23 @@

loggings[p] = {}

-assert (p.grad is not None and p.data.grad is not None) is False
-grad = p.grad if p.grad is not None else p.data.grad
+# assert (p.grad is not None and p.data.grad is not None) is False
+# grad = p.grad if p.grad is not None else p.data.grad
+# grad = get_grad_from_parameter(p)
+grad = get_grad_from_sliced_or_param(p)
+
+from nanotron import constants
+from nanotron.constants import get_debug_save_path
+
+# debug_save_path = constants.DEBUG_SAVE_PATH.format(constants.CONFIG.general.run, constants.ITERATION_STEP)
+debug_save_path = get_debug_save_path(constants.CONFIG.general.run, constants.ITERATION_STEP)
+
+if constants.is_ready_to_log is True and constants.CONFIG.logging.monitor_model_states is True:
+    torch.save(grad, f"{debug_save_path}/{self.params_id_to_param_names[id(p)]}_before_update_grad.pt")
+    torch.save(
+        data, f"{debug_save_path}/{self.params_id_to_param_names[id(p)]}_before_update_weight.pt"
+    )
+
assert isinstance(grad, torch.Tensor)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
@@ -110,16 +134,19 @@
exp_avg_sq = exp_avg_sq / bias_correction2

# denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
-denom = exp_avg_sq.sqrt() + group["eps"]
+denom = (exp_avg_sq + group["eps"]).sqrt()
normalized_grad = exp_avg / denom

lr = group["lr"]
# p.data.addcdiv_(-step_size, exp_avg, denom)
-new_data = data - lr * normalized_grad
+new_data = data - lr * (normalized_grad + (group["weight_decay"] * data))
new_data.requires_grad = True
-p.data = new_data

-assert p.data is new_data
+# p.data = new_data
+# assert p.data is new_data
+
+set_data_for_sliced_or_param(p, new_data)
+assert get_data_from_sliced_or_param(p) is new_data

state["exp_avg"] = exp_avg
state["exp_avg_sq"] = exp_avg_sq
@@ -146,11 +173,13 @@
def zero_grad(self):
for group in self.param_groups:
for p in group["params"]:
-if p.grad is not None:
-    p.grad = None
+# if p.grad is not None:
+#     p.grad = None

-if p.data.grad is not None:
-    p.data.grad = None
+# if p.data.grad is not None:
+#     p.data.grad = None
+# set_grad_none_for_param(p)
+set_grad_none_for_sliced_or_param(p)

assert p.grad is None
assert p.data.grad is None
@@ -308,12 +337,15 @@ def step(self):
num_param_has_grads = 0
for g in self.param_groups:
for p in g["params"]:
-if p.data.__class__ == torch.Tensor:
-    if p.grad is not None:
-        num_param_has_grads += 1
-elif p.data.__class__ == FP8Tensor:
-    if hasattr(p.data, "_temp_grad") and p.data._temp_grad is not None:
-        num_param_has_grads += 1
+# if p.data.__class__ == torch.Tensor:
+#     if p.grad is not None:
+#         num_param_has_grads += 1
+# elif p.data.__class__ == FP8Tensor:
+#     if hasattr(p.data, "_temp_grad") and p.data._temp_grad is not None:
+#         num_param_has_grads += 1
+grad = get_grad_from_parameter(p)
+if p is not None:
+    num_param_has_grads += 1

assert num_param_has_grads > 0

@@ -341,12 +373,14 @@ def step(self):
fp32_data = fp16_data.to(torch.float32)
# NOTE: the bias of FP8 parameter saves its gradient in p.data.grad
# and the weight, and bias of non-FP8 parameter saves its gradient in p.grad
-try:
-    assert (p.data.grad is None and p.grad is None) is False
-except:
-    assert 1 == 1
-
-grad = p.data.grad if p.data.grad is not None else p.grad
+# try:
+#     assert (p.data.grad is None and p.grad is None) is False
+# except:
+#     assert 1 == 1
+# assert (p.data.grad is None and p.grad is None) is False
+
+# grad = p.data.grad if p.data.grad is not None else p.grad
+grad = get_grad_from_parameter(p)
fp32_grad = grad.to(torch.float32)

if p.__class__ == NanotronParameter:
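Two changes in the custom Adam step above are worth spelling out: `eps` now sits inside the square root, and a weight-decay term is applied directly to the parameters alongside the normalized gradient, in the decoupled (AdamW-style) form. With the bias-corrected moments the code computes, the update is

$$
\hat m_t = \frac{m_t}{1-\beta_1^t}, \qquad
\hat v_t = \frac{v_t}{1-\beta_2^t}, \qquad
\theta_t \leftarrow \theta_{t-1} - \eta \left( \frac{\hat m_t}{\sqrt{\hat v_t + \epsilon}} + \lambda\, \theta_{t-1} \right)
$$

where $\eta$ is `group["lr"]` and $\lambda$ is `group["weight_decay"]` (0.1 in the updated config); note the denominator is $\sqrt{\hat v_t + \epsilon}$ rather than $\sqrt{\hat v_t} + \epsilon$.

The FP16 optimizer hunk keeps the usual master-weights recipe: cast the low-precision weight and its gradient to FP32, update there, then write back. A condensed sketch of that general pattern (an illustration only, not nanotron's exact code path):

```python
import torch

@torch.no_grad()
def master_weight_step(p_fp16: torch.Tensor, grad_fp16: torch.Tensor, lr: float = 1e-3) -> None:
    """One plain-SGD update carried out on an FP32 master copy, then written back to FP16."""
    fp32_data = p_fp16.to(torch.float32)   # FP32 master copy of the weight
    fp32_grad = grad_fp16.to(torch.float32)
    fp32_data -= lr * fp32_grad            # update in full precision to limit FP16 round-off
    p_fp16.copy_(fp32_data.to(torch.float16))

w = torch.randn(4, dtype=torch.float16)
g = torch.randn(4, dtype=torch.float16)
master_weight_step(w, g)
print(w.dtype)  # torch.float16
```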
