diff --git a/pipeline/train/instruction_following.py b/pipeline/train/instruction_following.py index fca6da5f..87664fcb 100755 --- a/pipeline/train/instruction_following.py +++ b/pipeline/train/instruction_following.py @@ -468,8 +468,7 @@ def main(): print(f"Total training steps: {total_training_steps}") args.warmup_steps = total_training_steps * args.warmup_steps_ratio if args.warmup_steps_ratio is not None else args.warmup_steps - args.warmup_steps = args.warmup_steps // args.gradient_accumulation_steps - args.total_training_steps = total_training_steps // args.gradient_accumulation_steps + args.total_training_steps = total_training_steps if args.lr_scheduler == "linear": lr_scheduler = get_linear_schedule_with_warmup(