From 75d39e14dacd2700411cc6e04392d55e1ab126df Mon Sep 17 00:00:00 2001 From: Bowen <81504862+Bowen12992@users.noreply.github.com> Date: Mon, 2 Dec 2024 11:32:24 +0800 Subject: [PATCH] Update train.py --- flagscale/train/train.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/flagscale/train/train.py b/flagscale/train/train.py index 01c1bde0a..f100b911b 100644 --- a/flagscale/train/train.py +++ b/flagscale/train/train.py @@ -1330,14 +1330,14 @@ def get_e2e_base_metrics(): if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler: prof = torch.profiler.profile( - schedule=torch.profiler.schedule( - wait=max(args.profile_step_start-1, 0), - warmup=1 if args.profile_step_start > 0 else 0, - active=args.profile_step_end-args.profile_step_start, - repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir), - record_shapes=True, - with_stack=True) + schedule=torch.profiler.schedule( + wait=max(args.profile_step_start-1, 0), + warmup=1 if args.profile_step_start > 0 else 0, + active=args.profile_step_end-args.profile_step_start, + repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir), + record_shapes=True, + with_stack=True) prof.start() while iteration < args.train_iters: