diff --git a/tutorials/examples/multinode/mila.ddp_gfn.small.4.slurm b/tutorials/examples/multinode/mila.ddp_gfn.small.4.slurm index 3c8168c3..e4ebd53f 100644 --- a/tutorials/examples/multinode/mila.ddp_gfn.small.4.slurm +++ b/tutorials/examples/multinode/mila.ddp_gfn.small.4.slurm @@ -46,7 +46,7 @@ echo " + Slurm Job Num Nodes: ${SLURM_JOB_NUM_NODES}" echo " + Slurm NodeList: ${SLURM_NODELIST}" #mpiexec.hydra -np 2 -ppn 2 -l -genv I_MPI_PIN_DOMAIN=[0xFFFF,0xFFFF0000] -genv CCL_WORKER_AFFINITY=32,48 -genv CCL_WORKER_COUNT=1 -genv O MP_NUM_THREADS=16 python -u train_hypergrid_multinode.py --ndim 8 --height 8 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256000 -mpiexec.hydra -np 4 -ppn 4 -l -genv CCL_WORKER_COUNT=8 python -u ../train_hypergrid_multinode.py --ndim 8 --height 8 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256 +mpiexec.hydra -np 4 -ppn 4 -l -genv CCL_WORKER_COUNT=8 python -u ../train_hypergrid.py --ndim 8 --height 8 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256 #./run_dist_ht.sh python -u train_hypergrid_multinode.py --ndim 8 --height 8 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256000 diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 065994e3..a54c847a 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -216,10 +216,12 @@ def gather_distributed_data( # Convert data to tensor if it's not already. if not isinstance(local_data, torch.Tensor): + print("+ converting tensor data via serialization") # Serialize complex data structures. serialized_data = pickle.dumps(local_data) local_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(serialized_data)) else: + print("+ tensor not converted") local_tensor = local_data # First gather sizes to allocate correct buffer sizes. @@ -227,9 +229,11 @@ def gather_distributed_data( size_list = [ torch.tensor([0], device=local_tensor.device) for _ in range(world_size) ] + print("+ all gather of local_size={} to size_list".format(local_size)) dist.all_gather(size_list, local_size) # Pad local tensor to maximum size. + print("+ padding local tensor") max_size = max(size.item() for size in size_list) if local_tensor.numel() < max_size: padding = torch.zeros( @@ -240,6 +244,7 @@ def gather_distributed_data( local_tensor = torch.cat((local_tensor, padding)) # Gather all tensors. + print("+ gathering all tensors from world_size={}".format(world_size)) tensor_list = [ torch.zeros(max_size, dtype=local_tensor.dtype, device=local_tensor.device) for _ in range(world_size) @@ -251,12 +256,12 @@ def gather_distributed_data( for tensor, size in zip(tensor_list, size_list): trimmed_tensor = tensor[: size.item()] if not isinstance(local_data, torch.Tensor): - + print("+ deserializing tensor.") # Deserialize data. - trimmed_data = pickle.loads(trimmed_tensor.cpu().numpy().tobytes()) - result.append(trimmed_data) + result = pickle.loads(trimmed_tensor.cpu().numpy().tobytes()) else: - result.append(trimmed_tensor) + print("+ tensor not deserialized") + result = trimmed_tensor return result @@ -687,10 +692,11 @@ def validate_hypergrid( env, gflownet, n_validation_samples, - visited_terminating_states: torch.Tensor, + visited_terminating_states: torch.Tensor | None, discovered_modes, ): - #validation_info = validate( # Standard validation shared across envs. + # Standard validation shared across envs. + #validation_info, visited_terminating_states = validate( # env, # gflownet, # n_validation_samples,