Commit

get distributed backend name via accelerator and check loss_scale before writing to tb (microsoft#374)

* check loss_scale before writing to tb

* get distributed backend name via accelerator

* add hccl distributed backend support
polisettyvarma authored May 13, 2024
1 parent bcedecd commit 7eb36a1
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion megatron/arguments.py
@@ -1138,7 +1138,7 @@ def _add_distributed_args(parser):
                        help='overlap pipeline parallel communication with forward and backward chunks',
                        dest='overlap_p2p_comm')
     group.add_argument('--distributed-backend', default='nccl',
-                       choices=['nccl', 'gloo', 'ccl'],
+                       choices=['nccl', 'gloo', 'ccl', 'hccl'],
                        help='Which backend to use for distributed training.')
     group.add_argument('--distributed-timeout-minutes', type=int, default=10,
                        help='Timeout minutes for torch.distributed.')
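
For context, a minimal standalone sketch (plain argparse, not the repository's parser) of how the extended choices list behaves: 'hccl' is now accepted for --distributed-backend, while unknown backend names are still rejected at parse time.

import argparse

# Hypothetical standalone parser mirroring the --distributed-backend option.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('distributed')
group.add_argument('--distributed-backend', default='nccl',
                   choices=['nccl', 'gloo', 'ccl', 'hccl'],
                   help='Which backend to use for distributed training.')

args = parser.parse_args(['--distributed-backend', 'hccl'])
print(args.distributed_backend)  # hccl
# parser.parse_args(['--distributed-backend', 'foo']) would exit with an argparse error.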
2 changes: 1 addition & 1 deletion megatron/initialize.py
@@ -222,7 +222,7 @@ def _initialize_distributed():
     else:
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group(
-                backend=args.distributed_backend,
+                backend=get_accelerator().communication_backend_name(),
                 world_size=args.world_size, rank=args.rank,
                 timeout=timedelta(minutes=args.distributed_timeout_minutes))

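
A hedged sketch of the resulting call path, assuming DeepSpeed's accelerator abstraction (deepspeed.accelerator.get_accelerator); the backend names are the typical mappings (CUDA -> 'nccl', Gaudi/HPU -> 'hccl', CPU -> 'ccl'), so the process group backend now tracks the active accelerator instead of the CLI flag.

from datetime import timedelta

import torch
from deepspeed.accelerator import get_accelerator  # assumes DeepSpeed is installed

def initialize_process_group(world_size, rank, timeout_minutes=10):
    # Hypothetical helper, not the repository's function: the backend string
    # comes from the active accelerator rather than args.distributed_backend.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend=get_accelerator().communication_backend_name(),
            world_size=world_size, rank=rank,
            timeout=timedelta(minutes=timeout_minutes))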
2 changes: 1 addition & 1 deletion megatron/training.py
@@ -887,7 +887,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                               args.consumed_train_samples)
             writer.add_scalar(f"lm-loss-training/{key}" + ' vs tokens', loss_dict[key],
                               args.consumed_train_tokens)
-        if args.fp16 and args.log_loss_scale_to_tensorboard:
+        if args.fp16 and loss_scale and args.log_loss_scale_to_tensorboard:
             writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration)
             writer.add_scalar('loss-scale/loss-scale vs samples', loss_scale,
                               args.consumed_train_samples)
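
Why the extra truthiness check helps (an illustrative sketch, not the repository's code): training_log receives loss_scale from its caller, and in some configurations it can be None even with fp16 enabled (an assumption for this sketch); guarding on loss_scale skips the TensorBoard write instead of passing a non-number to add_scalar.

from torch.utils.tensorboard import SummaryWriter

def log_loss_scale(writer, loss_scale, iteration,
                   fp16=True, log_loss_scale_to_tensorboard=True):
    # Hypothetical helper mirroring the guarded write in training_log.
    if fp16 and loss_scale and log_loss_scale_to_tensorboard:
        writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration)

writer = SummaryWriter(log_dir='/tmp/tb-example')
log_loss_scale(writer, None, iteration=10)     # skipped: no loss scale reported
log_loss_scale(writer, 65536.0, iteration=10)  # logged normally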
