Commit

small change
josephdviviano committed Dec 10, 2024
1 parent ebcf20b commit bdc36d4
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tutorials/examples/train_hypergrid.py
@@ -48,10 +48,10 @@ def average_gradients(model):
         param.grad.data /= size


-def initalize_distributed_compute(dist_backend: str = "ccl"):
+def initialize_distributed_compute(dist_backend: str = "ccl"):
     """Initalizes distributed compute using either ccl or mpi backends."""
-    global my_rank  # TODO: remove globals?
-    global my_size  # TODO: remove globals?
+    #global my_rank  # TODO: remove globals?
+    #global my_size  # TODO: remove globals?

     pmi_size = int(os.environ.get("PMI_SIZE", "0"))  # 0 or 1 default value?
     print("+ Initalizing distributed compute, PMI_SIZE={}".format(pmi_size))
@@ -90,8 +90,11 @@ def initalize_distributed_compute(dist_backend: str = "ccl"):

     my_rank = dist.get_rank()  # Global!
     my_size = dist.get_world_size()  # Global!
+
     print(f"+ My rank: {my_rank} size: {my_size}")
+
+    return (my_rank, my_size)


 def main(args):  # noqa: C901
     seed = args.seed if args.seed != 0 else DEFAULT_SEED
@@ -110,7 +113,7 @@ def main(args):  # noqa: C901
         wandb.config.update(args)

     if args.distributed:
-        initalize_distributed_compute()
+        my_rank, my_size = initialize_distributed_compute()
         rank = dist.get_rank()
         world_size = torch.distributed.get_world_size()
         print(f"Running with DDP on rank {rank}/{world_size}.")
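
For reference, a minimal usage sketch, not part of this commit: it assumes initialize_distributed_compute is available from the module shown above and that an args object with a distributed flag exists, and simply illustrates consuming the returned (rank, world_size) tuple instead of relying on module-level globals.

    # Hypothetical sketch -- not from the repository.
    # initialize_distributed_compute is the helper changed in the diff above;
    # the args object and everything else here are assumed for illustration.
    import torch.distributed as dist

    def run(args):
        if args.distributed:
            # The helper now returns (rank, world_size) rather than setting globals.
            my_rank, my_size = initialize_distributed_compute()
            assert my_rank == dist.get_rank()
            assert my_size == dist.get_world_size()
            print(f"Running with DDP on rank {my_rank}/{my_size}.")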
