diff --git a/examples/fp8/ablations/configs/sanity_bf16.yaml b/examples/fp8/ablations/configs/sanity_bf16.yaml
index 08d7ab5d..0b1c1880 100644
--- a/examples/fp8/ablations/configs/sanity_bf16.yaml
+++ b/examples/fp8/ablations/configs/sanity_bf16.yaml
@@ -73,9 +73,9 @@ model:
     intermediate_size: 2048
     is_llama_config: true
     max_position_embeddings: 256
-    num_attention_heads: 4
+    num_attention_heads: 16
     num_hidden_layers: 2
-    num_key_value_heads: 4
+    num_key_value_heads: 16
     pad_token_id: null
     pretraining_tp: 1
     rms_norm_eps: 1.0e-05
@@ -123,7 +123,7 @@ optimizer:
     lr_decay_starting_step: null
     lr_decay_steps: null
     lr_decay_style: cosine
-    lr_warmup_steps: 1000 # 10% warm up of total training steps
+    lr_warmup_steps: 200 # 10% warm up of total training steps
     lr_warmup_style: linear
     min_decay_lr: 0.00006
   optimizer_factory:
@@ -158,7 +158,7 @@ tokens:
   batch_accumulation_per_replica: 1
   limit_test_batches: 0
   limit_val_batches: 0
-  micro_batch_size: 256 # 256
+  micro_batch_size: 128 # 256
   # micro_batch_size: 1
   sequence_length: 256
   train_steps: 24376
diff --git a/src/nanotron/fp8/optim.py b/src/nanotron/fp8/optim.py
index a487118b..f013142f 100644
--- a/src/nanotron/fp8/optim.py
+++ b/src/nanotron/fp8/optim.py
@@ -167,7 +167,8 @@ def step(self, closure=None):
                 # print(f"[Ref Adam] exp_avg: {exp_avg[:2, :2]} \n")
                 # print(f"[Ref Adam] denom: {denom[:2, :2]} \n")
 
-                p.data.addcdiv_(-step_size, exp_avg, denom)
+                # p.data.addcdiv_(-step_size, exp_avg, denom)
+                p.data = p.data - step_size * (exp_avg / denom)
 
                 # if p.ndim != 1:
                 #     print(f"[Ref Adam] updated p: {p.data[:2, :2]} \n")
diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index 38d70714..02b1cd88 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -372,7 +372,24 @@ def optimizer(param_groups):
         else:
 
             def optimizer(param_groups):
-                return torch.optim.Adam(
+                # return torch.optim.Adam(
+                #     param_groups,
+                #     lr=optimizer_args.learning_rate_scheduler.learning_rate,
+                #     weight_decay=optimizer_args.weight_decay,
+                #     eps=optimizer_args.optimizer_factory.adam_eps,
+                #     betas=(
+                #         optimizer_args.optimizer_factory.adam_beta1,
+                #         optimizer_args.optimizer_factory.adam_beta2,
+                #     ),
+                #     # fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
+                #     # NOTE: fused (bool, optional) – whether the fused implementation (CUDA only) is used.
+                #     # Currently, torch.float64, torch.float32, torch.float16, and torch.bfloat16
+                #     # in FP8 training, model parameters are INT8
+                #     fused=False,
+                # )
+                from nanotron.fp8.optim import Adam
+
+                return Adam(
                     param_groups,
                     lr=optimizer_args.learning_rate_scheduler.learning_rate,
                     weight_decay=optimizer_args.weight_decay,
@@ -381,11 +398,6 @@ def optimizer(param_groups):
                         optimizer_args.optimizer_factory.adam_beta1,
                         optimizer_args.optimizer_factory.adam_beta2,
                     ),
-                    # fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
-                    # NOTE: fused (bool, optional) – whether the fused implementation (CUDA only) is used.
-                    # Currently, torch.float64, torch.float32, torch.float16, and torch.bfloat16
-                    # in FP8 training, model parameters are INT8
-                    fused=False,
                 )

    elif optimizer_args.optimizer_factory.name == "sgd":
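
Note (not part of the patch): a minimal sketch checking that the out-of-place update introduced in src/nanotron/fp8/optim.py computes the same result as the in-place addcdiv_ call it replaces. The tensor names mirror the diff, but the shapes and values below are made up for illustration; the rebinding of p.data presumably avoids in-place mutation of parameter tensors that may not support it in FP8 training.

import torch

# Hypothetical stand-ins for the optimizer state the diff touches.
p = torch.randn(4, 4)
exp_avg = torch.randn(4, 4)        # first-moment estimate
denom = torch.rand(4, 4) + 1e-3    # sqrt of second moment plus eps, kept positive
step_size = 1e-3

# Old form: in-place addcdiv_ (the removed line used the deprecated
# positional-scalar signature addcdiv_(-step_size, exp_avg, denom)).
p_inplace = p.clone()
p_inplace.addcdiv_(exp_avg, denom, value=-step_size)

# New form from the diff: plain tensor arithmetic producing a fresh tensor.
p_out_of_place = p - step_size * (exp_avg / denom)

assert torch.allclose(p_inplace, p_out_of_place)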