Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Bowen12992 authored Dec 2, 2024
1 parent c4a809b commit 75d39e1
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions flagscale/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,14 +1330,14 @@ def get_e2e_base_metrics():

if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
prof = torch.profiler.profile(
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start-1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end-args.profile_step_start,
repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
with_stack=True)
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start-1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end-args.profile_step_start,
repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
with_stack=True)
prof.start()

while iteration < args.train_iters:
Expand Down

0 comments on commit 75d39e1

Please sign in to comment.