From 7eb36a11b3a9c48ed07b93692ccf22bfb5577f7e Mon Sep 17 00:00:00 2001
From: Polisetty V R K Jyothendra Varma
Date: Mon, 13 May 2024 15:25:56 +0530
Subject: [PATCH] get distributed backend name via accelerator and check
 loss_scale before writing to tb (#374)

* check loss_scale before writing to tb

* get distributed backend name via accelerator

* add hccl distributed backend support
---
 megatron/arguments.py  | 2 +-
 megatron/initialize.py | 2 +-
 megatron/training.py   | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 631d4b12e8..dad993be04 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -1138,7 +1138,7 @@ def _add_distributed_args(parser):
                        help='overlap pipeline parallel communication with forward and backward chunks',
                        dest='overlap_p2p_comm')
     group.add_argument('--distributed-backend', default='nccl',
-                       choices=['nccl', 'gloo', 'ccl'],
+                       choices=['nccl', 'gloo', 'ccl', 'hccl'],
                        help='Which backend to use for distributed training.')
     group.add_argument('--distributed-timeout-minutes', type=int, default=10,
                        help='Timeout minutes for torch.distributed.')
diff --git a/megatron/initialize.py b/megatron/initialize.py
index 31f26c5086..85c37a5767 100644
--- a/megatron/initialize.py
+++ b/megatron/initialize.py
@@ -222,7 +222,7 @@ def _initialize_distributed():
         else:
             if not torch.distributed.is_initialized():
                 torch.distributed.init_process_group(
-                    backend=args.distributed_backend,
+                    backend=get_accelerator().communication_backend_name(),
                     world_size=args.world_size, rank=args.rank,
                     timeout=timedelta(minutes=args.distributed_timeout_minutes))
diff --git a/megatron/training.py b/megatron/training.py
index 7b6acffede..19b8a6c71f 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -887,7 +887,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                               args.consumed_train_samples)
             writer.add_scalar(f"lm-loss-training/{key}" + ' vs tokens', loss_dict[key],
                               args.consumed_train_tokens)
-        if args.fp16 and args.log_loss_scale_to_tensorboard:
+        if args.fp16 and loss_scale and args.log_loss_scale_to_tensorboard:
             writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration)
             writer.add_scalar('loss-scale/loss-scale vs samples', loss_scale,
                               args.consumed_train_samples)
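
Note: the sketch below only illustrates the intent of the two runtime changes in this
patch. It assumes a DeepSpeed installation that exposes deepspeed.accelerator.get_accelerator()
(the same call the patch uses in megatron/initialize.py); resolve_backend and log_loss_scale
are illustrative helper names, not functions that exist in this repository.

# Minimal sketch (not part of the patch to apply) of the two runtime changes.
# Assumes DeepSpeed is installed; resolve_backend and log_loss_scale are
# illustrative names only.
from deepspeed.accelerator import get_accelerator


def resolve_backend() -> str:
    # Ask the active accelerator which torch.distributed backend it needs,
    # instead of trusting the --distributed-backend CLI value: CUDA devices
    # typically report 'nccl', Intel devices 'ccl', and Habana/HPU 'hccl'.
    return get_accelerator().communication_backend_name()


def log_loss_scale(writer, loss_scale, iteration, fp16, log_to_tb):
    # loss_scale may be None (e.g. when no dynamic loss scaler is in use);
    # the patch only writes it to TensorBoard when it is actually set.
    if fp16 and loss_scale and log_to_tb:
        writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration)


if __name__ == '__main__':
    print('distributed backend:', resolve_backend())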