diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 67ce0e86..72d9746d 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -48,10 +48,10 @@ def average_gradients(model): param.grad.data /= size -def initalize_distributed_compute(dist_backend: str = "ccl"): +def initialize_distributed_compute(dist_backend: str = "ccl"): """Initalizes distributed compute using either ccl or mpi backends.""" - global my_rank # TODO: remove globals? - global my_size # TODO: remove globals? + #global my_rank # TODO: remove globals? + #global my_size # TODO: remove globals? pmi_size = int(os.environ.get("PMI_SIZE", "0")) # 0 or 1 default value? print("+ Initalizing distributed compute, PMI_SIZE={}".format(pmi_size)) @@ -90,8 +90,11 @@ def initalize_distributed_compute(dist_backend: str = "ccl"): my_rank = dist.get_rank() # Global! my_size = dist.get_world_size() # Global! + print(f"+ My rank: {my_rank} size: {my_size}") + return (my_rank, my_size) + def main(args): # noqa: C901 seed = args.seed if args.seed != 0 else DEFAULT_SEED @@ -110,7 +113,7 @@ def main(args): # noqa: C901 wandb.config.update(args) if args.distributed: - initalize_distributed_compute() + my_rank, my_size = initialize_distributed_compute() rank = dist.get_rank() world_size = torch.distributed.get_world_size() print(f"Running with DDP on rank {rank}/{world_size}.")