
Commit

clean custom fp8_optim
xrsrke committed Jul 10, 2024
1 parent 1068f89 commit 6400e86
Showing 4 changed files with 51 additions and 104 deletions.
10 changes: 5 additions & 5 deletions examples/fp8/ablations/configs/sanity_bf16.yaml
@@ -22,15 +22,15 @@ general:
consumed_train_samples: null
ignore_sanity_checks: true
project: fp8_for_nanotron
- run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_fp8_branch_and_layernorm_and_custom_adam_and_tp_1_and_lr_1.0
+ run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_fp8_branch_and_layernorm_and_custom_adam_and_tp_1_and_no_weight_decay
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
- monitor_model_states: false
+ monitor_model_states: true
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
@@ -63,11 +63,11 @@ model:
optimizer:
accumulate_grad_in_fp32: false
learning_rate_scheduler:
- learning_rate: 1.0
+ learning_rate: 0.0006
lr_decay_starting_step: null
lr_decay_steps: null
lr_decay_style: cosine
- lr_warmup_steps: 200 # 10% warm up of total training steps
+ lr_warmup_steps: 0 # 10% warm up of total training steps
lr_warmup_style: linear
min_decay_lr: 0.00006
optimizer_factory:
@@ -76,7 +76,7 @@ optimizer:
adam_eps: 1.0e-08
name: custom_adam
torch_adam_is_fused: true
- weight_decay: 0.1
+ weight_decay: 0.
zero_stage: 0

parallelism:
@@ -22,7 +22,7 @@ general:
consumed_train_samples: null
ignore_sanity_checks: true
project: fp8_for_nanotron
- run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_main_branch_and_layernorm
+ run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_main_branch_and_layernorm_and_no_weight_decay_and_no_warmup
seed: 42
step: null
lighteval: null
@@ -98,7 +98,7 @@ optimizer:
lr_decay_starting_step: null
lr_decay_steps: null
lr_decay_style: cosine
- lr_warmup_steps: 200 # 10% warm up of total training steps
+ lr_warmup_steps: 0 # 10% warm up of total training steps
lr_warmup_style: linear
min_decay_lr: 0.00006
optimizer_factory:
@@ -107,15 +107,15 @@ optimizer:
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
- weight_decay: 0.1
+ weight_decay: 0.
zero_stage: 0

parallelism:
dp: 1
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
- tp: 2
+ tp: 1
tp_linear_async_communication: false
tp_mode: ALL_REDUCE

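Both config diffs above move to learning_rate: 0.0006 with lr_warmup_steps: 0, weight_decay: 0., and cosine decay down to min_decay_lr: 0.00006. As a rough illustration of the schedule those settings describe, here is a hand-written sketch; it is not nanotron's scheduler code, and the total-step count plus the assumption that lr_decay_steps: null falls back to the remaining training steps are placeholders:

import math

def lr_at(step, total_steps, peak_lr=6.0e-4, min_lr=6.0e-5, warmup_steps=0):
    # Linear warmup followed by cosine decay to min_lr (illustrative only).
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

# With lr_warmup_steps: 0 the run starts at the full 6.0e-4 and decays to 6.0e-5.
print(lr_at(0, 10_000), lr_at(5_000, 10_000), lr_at(10_000, 10_000))
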
137 changes: 42 additions & 95 deletions src/nanotron/fp8/optim.py
@@ -73,124 +73,71 @@ def step(self, closure=None):
loggings = {}
for group in self.param_groups:
for p in group["params"]:
# if p.grad is None:
# continue

loggings[p] = {}

assert (p.grad is not None and p.data.grad is not None) is False
grad = p.grad if p.grad is not None else p.data.grad
data = p.data

assert isinstance(grad, torch.Tensor)
loggings[p]["hp_grad"] = compute_stas(grad)
loggings[p]["hp_p"] = compute_stas(grad)

# try:
# assert isinstance(grad, torch.Tensor)
# except:
# assert 1 == 1

# if p.ndim != 1:
# print(f"[Ref Adam] original grad: {grad[:2, :2]} \n")

# if grad.is_sparse:
# raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
amsgrad = group["amsgrad"]

state = self.state[p]
data = p.data
assert isinstance(data, torch.Tensor)

# State initialization
if len(state) == 0:
state["step"] = torch.tensor(0.0, dtype=torch.float32)
# Exponential moving average of gradient values
state["step"] = torch.tensor(0.0, dtype=data.dtype)
state["exp_avg"] = torch.zeros_like(data, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(data, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(data, memory_format=torch.preserve_format)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
loggings[p] = {}

loggings[p]["hp_exp_avg"] = compute_stas(exp_avg)
loggings[p]["hp_exp_avg_sq"] = compute_stas(exp_avg_sq)
assert (p.grad is not None and p.data.grad is not None) is False
grad = p.grad if p.grad is not None else p.data.grad
assert isinstance(grad, torch.Tensor)

if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]

loggings[p]["group:beta1"] = {"value": beta1}
loggings[p]["group:beta2"] = {"value": beta2}
loggings[p]["group:lr"] = {"value": group["lr"]}
loggings[p]["group:eps"] = {"value": group["eps"]}

# if p.ndim != 1:
# print(
# f"[Ref Adam] original exp_avg: exp_avg.data={exp_avg.data[:2, :2]}, exp_avg.dtype={exp_avg.dtype} \n"
# )
# print(
# f"[Ref Adam] original exp_avg_sq: exp_avg_sq.data={exp_avg_sq.data[:2, :2]}, exp_avg_sq.dtype={exp_avg_sq.dtype} \n"
# )
# print(f"[Ref Adam] beta1: {beta1}, beta2: {beta2}")

state["step"] += 1
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]

loggings[p]["bias_correction1"] = {"value": bias_correction1}
loggings[p]["bias_correction2"] = {"value": bias_correction2}

# if p.ndim != 1:
# print(f"[Ref Adam]: bias_correction1: {bias_correction1}, bias_correction2: {bias_correction2}")

if group["weight_decay"] != 0:
grad = grad.add(group["weight_decay"], data)
# if p.ndim != 1:
# print(f"[Ref Adam] grad after weight decay: {grad[:2, :2]} \n")
# if group["weight_decay"] != 0:
# grad = grad.add(group["weight_decay"], data)

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

# if p.ndim != 1:
# print(f"[Ref Adam] after mul and add: exp_avg: {exp_avg[:2, :2]} \n")
# print(f"[Ref Adam] after mul and add: exp_avg_sq: {exp_avg_sq[:2, :2]} \n")

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
else:
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
# exp_avg.mul_(beta1).add_(1 - beta1, grad)
# exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

# if p.ndim != 1:
# print(f"[Ref Adam] exp_avg_sq.sqrt(): {exp_avg_sq.sqrt()[:2, :2]} \n")
# print(f"[Ref Adam] math.sqrt(bias_correction2)): {math.sqrt(bias_correction2)} \n")
# print(f"[Ref Adam] group['eps']: {group['eps']} \n")
exp_avg = beta1 * exp_avg + (1 - beta1) * grad
exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad.pow(2)

step_size = group["lr"] / bias_correction1
loggings[p]["denom"] = compute_stas(denom)
loggings[p]["step_size"] = {"value": step_size}
step = state["step"]
step += 1
bias_correction1 = 1 - (beta1**step)
bias_correction2 = 1 - (beta2**step)

# if p.ndim != 1:
# print(f"[Ref Adam] step_size: {step_size} \n")
# print(f"[Ref Adam] exp_avg: {exp_avg[:2, :2]} \n")
# print(f"[Ref Adam] denom: {denom[:2, :2]} \n")
exp_avg = exp_avg / bias_correction1
exp_avg_sq = exp_avg_sq / bias_correction2

# denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
denom = exp_avg_sq.sqrt() + group["eps"]
normalized_grad = exp_avg / denom

lr = group["lr"]
# p.data.addcdiv_(-step_size, exp_avg, denom)
new_data = data - step_size * (exp_avg / denom)
new_data = data - lr * normalized_grad
new_data.requires_grad = True
p.data = new_data

assert p.data is new_data

# if p.ndim != 1:
# print(f"[Ref Adam] updated p: {p.data[:2, :2]} \n")
state["exp_avg"] = exp_avg
state["exp_avg_sq"] = exp_avg_sq
state["step"] = step

# break
# break
loggings[p]["hp_grad"] = compute_stas(grad)
loggings[p]["hp_p"] = compute_stas(p)
loggings[p]["group:lr"] = {"value": lr}
loggings[p]["group:eps"] = {"value": group["eps"]}
loggings[p]["hp_exp_avg"] = compute_stas(exp_avg)
loggings[p]["hp_exp_avg_sq"] = compute_stas(exp_avg_sq)
loggings[p]["group:beta1"] = {"value": beta1}
loggings[p]["group:beta2"] = {"value": beta2}

loggings[p]["bias_correction1"] = {"value": bias_correction1}
loggings[p]["bias_correction2"] = {"value": bias_correction2}
loggings[p]["denom"] = compute_stas(denom)
loggings[p]["normalized_grad"] = compute_stas(normalized_grad)

self.loggings = loggings

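Taken together, the added lines in this hunk replace the in-place Adam arithmetic and debug prints with an out-of-place update, drop the weight-decay term (matching weight_decay: 0. in the configs), and collect all logging at the end of the loop. Below is a minimal standalone sketch of that update with the compute_stas logging stripped out; the lr, beta, and eps defaults are placeholders rather than values read from the diff:

import torch

@torch.no_grad()
def custom_adam_step(p, state, lr=6.0e-4, beta1=0.9, beta2=0.999, eps=1.0e-8):
    # Out-of-place Adam step mirroring the cleaned step() above (sketch only).
    grad = p.grad if p.grad is not None else p.data.grad
    if len(state) == 0:
        state["step"] = torch.tensor(0.0, dtype=p.data.dtype)
        state["exp_avg"] = torch.zeros_like(p.data)
        state["exp_avg_sq"] = torch.zeros_like(p.data)

    # First/second moment estimates, computed out of place rather than via mul_/add_.
    exp_avg = beta1 * state["exp_avg"] + (1 - beta1) * grad
    exp_avg_sq = beta2 * state["exp_avg_sq"] + (1 - beta2) * grad.pow(2)

    step = state["step"] + 1
    exp_avg_hat = exp_avg / (1 - beta1**step)        # bias correction applied to the moments
    exp_avg_sq_hat = exp_avg_sq / (1 - beta2**step)

    denom = exp_avg_sq_hat.sqrt() + eps              # eps added after the sqrt, no weight decay
    new_data = p.data - lr * (exp_avg_hat / denom)
    new_data.requires_grad = True                    # the diff swaps a fresh tensor into p.data
    p.data = new_data

    state["exp_avg"], state["exp_avg_sq"], state["step"] = exp_avg, exp_avg_sq, step

Dividing the moments by their bias corrections and adding eps after the square root gives the same update (up to floating-point differences) as the removed formulation that folded 1/bias_correction1 into the step size.
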
Binary file modified src/nanotron/nn/layer_norm.py
Binary file not shown.
