diff --git a/examples/fp8/ablations/configs/sanity_bf16.yaml b/examples/fp8/ablations/configs/sanity_bf16.yaml
index bdafac88..7b5388ad 100644
--- a/examples/fp8/ablations/configs/sanity_bf16.yaml
+++ b/examples/fp8/ablations/configs/sanity_bf16.yaml
@@ -22,7 +22,7 @@ general:
   consumed_train_samples: null
   ignore_sanity_checks: true
   project: fp8_for_nanotron
-  run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_fp8_branch_and_layernorm_and_custom_adam_and_tp_1_and_lr_1.0
+  run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_fp8_branch_and_layernorm_and_custom_adam_and_tp_1_and_no_weight_decay
   seed: 42
   step: null
 lighteval: null
@@ -30,7 +30,7 @@ logging:
   iteration_step_info_interval: 1
   log_level: info
   log_level_replica: info
-  monitor_model_states: false
+  monitor_model_states: true
 model:
   ddp_bucket_cap_mb: 25
   dtype: bfloat16
@@ -63,11 +63,11 @@ model:
 optimizer:
   accumulate_grad_in_fp32: false
   learning_rate_scheduler:
-    learning_rate: 1.0
+    learning_rate: 0.0006
     lr_decay_starting_step: null
     lr_decay_steps: null
     lr_decay_style: cosine
-    lr_warmup_steps: 200 # 10% warm up of total training steps
+    lr_warmup_steps: 0 # 10% warm up of total training steps
     lr_warmup_style: linear
     min_decay_lr: 0.00006
   optimizer_factory:
@@ -76,7 +76,7 @@ optimizer:
     adam_eps: 1.0e-08
     name: custom_adam
     torch_adam_is_fused: true
-  weight_decay: 0.1
+  weight_decay: 0.
   zero_stage: 0

 parallelism:
diff --git a/examples/fp8/ablations/configs/sanity_bf16_for_main_branch.yaml b/examples/fp8/ablations/configs/sanity_bf16_for_main_branch.yaml
index 43a30d0a..27f5d7c9 100644
--- a/examples/fp8/ablations/configs/sanity_bf16_for_main_branch.yaml
+++ b/examples/fp8/ablations/configs/sanity_bf16_for_main_branch.yaml
@@ -22,7 +22,7 @@ general:
   consumed_train_samples: null
   ignore_sanity_checks: true
   project: fp8_for_nanotron
-  run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_main_branch_and_layernorm
+  run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_main_branch_and_layernorm_and_no_weight_decay_and_no_warmup
   seed: 42
   step: null
 lighteval: null
@@ -98,7 +98,7 @@ optimizer:
     lr_decay_starting_step: null
     lr_decay_steps: null
     lr_decay_style: cosine
-    lr_warmup_steps: 200 # 10% warm up of total training steps
+    lr_warmup_steps: 0 # 10% warm up of total training steps
     lr_warmup_style: linear
     min_decay_lr: 0.00006
   optimizer_factory:
@@ -107,7 +107,7 @@ optimizer:
     adam_eps: 1.0e-08
     name: adamW
     torch_adam_is_fused: true
-  weight_decay: 0.1
+  weight_decay: 0.
   zero_stage: 0

 parallelism:
@@ -115,7 +115,7 @@ parallelism:
   expert_parallel_size: 1
   pp: 1
   pp_engine: 1f1b
-  tp: 2
+  tp: 1
   tp_linear_async_communication: false
   tp_mode: ALL_REDUCE

diff --git a/src/nanotron/fp8/optim.py b/src/nanotron/fp8/optim.py
index 1e72c412..65f1abe6 100644
--- a/src/nanotron/fp8/optim.py
+++ b/src/nanotron/fp8/optim.py
@@ -73,124 +73,71 @@ def step(self, closure=None):
         loggings = {}
         for group in self.param_groups:
             for p in group["params"]:
-                # if p.grad is None:
-                #     continue
-
-                loggings[p] = {}
-
-                assert (p.grad is not None and p.data.grad is not None) is False
-                grad = p.grad if p.grad is not None else p.data.grad
-                data = p.data
-
-                assert isinstance(grad, torch.Tensor)
-                loggings[p]["hp_grad"] = compute_stas(grad)
-                loggings[p]["hp_p"] = compute_stas(grad)
-
-                # try:
-                #     assert isinstance(grad, torch.Tensor)
-                # except:
-                #     assert 1 == 1
-
-                # if p.ndim != 1:
-                #     print(f"[Ref Adam] original grad: {grad[:2, :2]} \n")
-
-                # if grad.is_sparse:
-                #     raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
                 amsgrad = group["amsgrad"]
                 state = self.state[p]
+                data = p.data
+                assert isinstance(data, torch.Tensor)

-                # State initialization
                 if len(state) == 0:
-                    state["step"] = torch.tensor(0.0, dtype=torch.float32)
-                    # Exponential moving average of gradient values
+                    state["step"] = torch.tensor(0.0, dtype=data.dtype)
                     state["exp_avg"] = torch.zeros_like(data, memory_format=torch.preserve_format)
-                    # Exponential moving average of squared gradient values
                     state["exp_avg_sq"] = torch.zeros_like(data, memory_format=torch.preserve_format)
-                    if amsgrad:
-                        # Maintains max of all exp. moving avg. of sq. grad. values
-                        state["max_exp_avg_sq"] = torch.zeros_like(data, memory_format=torch.preserve_format)

-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                loggings[p] = {}

-                loggings[p]["hp_exp_avg"] = compute_stas(exp_avg)
-                loggings[p]["hp_exp_avg_sq"] = compute_stas(exp_avg_sq)
+                assert (p.grad is not None and p.data.grad is not None) is False
+                grad = p.grad if p.grad is not None else p.data.grad
+                assert isinstance(grad, torch.Tensor)

-                if amsgrad:
-                    max_exp_avg_sq = state["max_exp_avg_sq"]
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 beta1, beta2 = group["betas"]

-                loggings[p]["group:beta1"] = {"value": beta1}
-                loggings[p]["group:beta2"] = {"value": beta2}
-                loggings[p]["group:lr"] = {"value": group["lr"]}
-                loggings[p]["group:eps"] = {"value": group["eps"]}
-
-                # if p.ndim != 1:
-                #     print(
-                #         f"[Ref Adam] original exp_avg: exp_avg.data={exp_avg.data[:2, :2]}, exp_avg.dtype={exp_avg.dtype} \n"
-                #     )
-                #     print(
-                #         f"[Ref Adam] original exp_avg_sq: exp_avg_sq.data={exp_avg_sq.data[:2, :2]}, exp_avg_sq.dtype={exp_avg_sq.dtype} \n"
-                #     )
-                #     print(f"[Ref Adam] beta1: {beta1}, beta2: {beta2}")
-
-                state["step"] += 1
-                bias_correction1 = 1 - beta1 ** state["step"]
-                bias_correction2 = 1 - beta2 ** state["step"]
-
-                loggings[p]["bias_correction1"] = {"value": bias_correction1}
-                loggings[p]["bias_correction2"] = {"value": bias_correction2}
-
-                # if p.ndim != 1:
-                #     print(f"[Ref Adam]: bias_correction1: {bias_correction1}, bias_correction2: {bias_correction2}")
-
-                if group["weight_decay"] != 0:
-                    grad = grad.add(group["weight_decay"], data)
-                    # if p.ndim != 1:
-                    #     print(f"[Ref Adam] grad after weight decay: {grad[:2, :2]} \n")
+                # if group["weight_decay"] != 0:
+                #     grad = grad.add(group["weight_decay"], data)

                 # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-
-                # if p.ndim != 1:
-                #     print(f"[Ref Adam] after mul and add: exp_avg: {exp_avg[:2, :2]} \n")
-                #     print(f"[Ref Adam] after mul and add: exp_avg_sq: {exp_avg_sq[:2, :2]} \n")
-
-                if amsgrad:
-                    # Maintains the maximum of all 2nd moment running avg. till now
-                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
-                    # Use the max. for normalizing running avg. of gradient
-                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
-                else:
-                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
+                # exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                # exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

-                # if p.ndim != 1:
-                #     print(f"[Ref Adam] exp_avg_sq.sqrt(): {exp_avg_sq.sqrt()[:2, :2]} \n")
-                #     print(f"[Ref Adam] math.sqrt(bias_correction2)): {math.sqrt(bias_correction2)} \n")
-                #     print(f"[Ref Adam] group['eps']: {group['eps']} \n")
+                exp_avg = beta1 * exp_avg + (1 - beta1) * grad
+                exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad.pow(2)

-                step_size = group["lr"] / bias_correction1
-                loggings[p]["denom"] = compute_stas(denom)
-                loggings[p]["step_size"] = {"value": step_size}
+                step = state["step"]
+                step += 1
+                bias_correction1 = 1 - (beta1**step)
+                bias_correction2 = 1 - (beta2**step)

-                # if p.ndim != 1:
-                #     print(f"[Ref Adam] step_size: {step_size} \n")
-                #     print(f"[Ref Adam] exp_avg: {exp_avg[:2, :2]} \n")
-                #     print(f"[Ref Adam] denom: {denom[:2, :2]} \n")
+                exp_avg = exp_avg / bias_correction1
+                exp_avg_sq = exp_avg_sq / bias_correction2
+                # denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
+                denom = exp_avg_sq.sqrt() + group["eps"]
+                normalized_grad = exp_avg / denom
+
+                lr = group["lr"]
                 # p.data.addcdiv_(-step_size, exp_avg, denom)
-                new_data = data - step_size * (exp_avg / denom)
+                new_data = data - lr * normalized_grad
                 new_data.requires_grad = True
                 p.data = new_data
                 assert p.data is new_data

-                # if p.ndim != 1:
-                #     print(f"[Ref Adam] updated p: {p.data[:2, :2]} \n")
+                state["exp_avg"] = exp_avg
+                state["exp_avg_sq"] = exp_avg_sq
+                state["step"] = step

-                # break
-            # break
+                loggings[p]["hp_grad"] = compute_stas(grad)
+                loggings[p]["hp_p"] = compute_stas(p)
+                loggings[p]["group:lr"] = {"value": lr}
+                loggings[p]["group:eps"] = {"value": group["eps"]}
+                loggings[p]["hp_exp_avg"] = compute_stas(exp_avg)
+                loggings[p]["hp_exp_avg_sq"] = compute_stas(exp_avg_sq)
+                loggings[p]["group:beta1"] = {"value": beta1}
+                loggings[p]["group:beta2"] = {"value": beta2}
+
+                loggings[p]["bias_correction1"] = {"value": bias_correction1}
+                loggings[p]["bias_correction2"] = {"value": bias_correction2}
+                loggings[p]["denom"] = compute_stas(denom)
+                loggings[p]["normalized_grad"] = compute_stas(normalized_grad)

         self.loggings = loggings
diff --git a/src/nanotron/nn/layer_norm.py b/src/nanotron/nn/layer_norm.py
index 688eaa78..9f88f092 100644
Binary files a/src/nanotron/nn/layer_norm.py and b/src/nanotron/nn/layer_norm.py differ
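
Note on the rewritten step() in src/nanotron/fp8/optim.py: with weight decay commented out, the per-parameter update reduces to plain Adam with explicit bias correction, computed out of place and then written back into state (including the bias-corrected moments). Below is a minimal standalone sketch of that update rule for reference; the function name adam_update is made up for illustration, lr and eps follow the config values above (0.0006 and 1.0e-08), and the betas default to the usual 0.9/0.999 since they are not visible in these hunks. This is a sketch, not the nanotron implementation.

import torch


def adam_update(data, grad, exp_avg, exp_avg_sq, step,
                lr=0.0006, beta1=0.9, beta2=0.999, eps=1.0e-08):
    # Running first and second moments, computed out of place as in the diff above.
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad.pow(2)

    # Explicit bias correction; like the diffed step(), the corrected moments
    # (not the raw running averages) are what the caller stores for the next step.
    step = step + 1
    exp_avg = exp_avg / (1 - beta1**step)
    exp_avg_sq = exp_avg_sq / (1 - beta2**step)

    denom = exp_avg_sq.sqrt() + eps
    normalized_grad = exp_avg / denom
    new_data = data - lr * normalized_grad
    return new_data, exp_avg, exp_avg_sq, step


# Example usage with random tensors (shapes are arbitrary).
p = torch.randn(4, 4)
g = torch.randn(4, 4)
m, v = torch.zeros_like(p), torch.zeros_like(p)
p, m, v, t = adam_update(p, g, m, v, step=0)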