Commit 9bbe332

logging

josephdviviano committed Dec 12, 2024
1 parent 7529766 commit 9bbe332
Showing 2 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion tutorials/examples/multinode/mila.ddp_gfn.small.4.slurm
@@ -46,7 +46,7 @@ echo " + Slurm Job Num Nodes: ${SLURM_JOB_NUM_NODES}"
echo " + Slurm NodeList: ${SLURM_NODELIST}"

#mpiexec.hydra -np 2 -ppn 2 -l -genv I_MPI_PIN_DOMAIN=[0xFFFF,0xFFFF0000] -genv CCL_WORKER_AFFINITY=32,48 -genv CCL_WORKER_COUNT=1 -genv OMP_NUM_THREADS=16 python -u train_hypergrid_multinode.py --ndim 8 --height 8 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256000
-mpiexec.hydra -np 4 -ppn 4 -l -genv CCL_WORKER_COUNT=8 python -u ../train_hypergrid_multinode.py --ndim 8 --height 8 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256
+mpiexec.hydra -np 4 -ppn 4 -l -genv CCL_WORKER_COUNT=8 python -u ../train_hypergrid.py --ndim 8 --height 8 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256

#./run_dist_ht.sh python -u train_hypergrid_multinode.py --ndim 8 --height 8 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256000

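Note on the launch line above (context, not part of this commit): `mpiexec.hydra` starts 4 ranks on one node (`-np 4 -ppn 4`), and `CCL_WORKER_COUNT=8` sets the number of oneCCL communication worker threads. Below is a minimal sketch of how a script launched this way might bootstrap `torch.distributed`; it assumes Intel MPI exports `PMI_RANK`/`PMI_SIZE`, and `init_distributed` is a hypothetical helper, not code from `train_hypergrid.py`.

```python
# Hypothetical sketch: bootstrapping torch.distributed under mpiexec.hydra.
# Assumes Intel MPI sets PMI_RANK / PMI_SIZE; not taken from train_hypergrid.py.
import os

import torch.distributed as dist


def init_distributed() -> tuple[int, int]:
    rank = int(os.environ.get("PMI_RANK", "0"))
    world_size = int(os.environ.get("PMI_SIZE", "1"))
    # Every rank must agree on the rendezvous address and port.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # The "ccl" backend would require the oneCCL bindings for PyTorch;
    # "gloo" is a safe CPU-only default for a sketch.
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    return rank, world_size
```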
18 changes: 12 additions & 6 deletions tutorials/examples/train_hypergrid.py
@@ -216,20 +216,24 @@ def gather_distributed_data(

    # Convert data to tensor if it's not already.
    if not isinstance(local_data, torch.Tensor):
+        print("+ converting tensor data via serialization")
        # Serialize complex data structures.
        serialized_data = pickle.dumps(local_data)
        local_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(serialized_data))
    else:
+        print("+ tensor not converted")
        local_tensor = local_data

    # First gather sizes to allocate correct buffer sizes.
    local_size = torch.tensor([local_tensor.numel()], device=local_tensor.device)
    size_list = [
        torch.tensor([0], device=local_tensor.device) for _ in range(world_size)
    ]
+    print("+ all gather of local_size={} to size_list".format(local_size))
    dist.all_gather(size_list, local_size)

    # Pad local tensor to maximum size.
+    print("+ padding local tensor")
    max_size = max(size.item() for size in size_list)
    if local_tensor.numel() < max_size:
        padding = torch.zeros(
@@ -240,6 +244,7 @@
        local_tensor = torch.cat((local_tensor, padding))

    # Gather all tensors.
+    print("+ gathering all tensors from world_size={}".format(world_size))
    tensor_list = [
        torch.zeros(max_size, dtype=local_tensor.dtype, device=local_tensor.device)
        for _ in range(world_size)
@@ -251,12 +256,12 @@
    for tensor, size in zip(tensor_list, size_list):
        trimmed_tensor = tensor[: size.item()]
        if not isinstance(local_data, torch.Tensor):
-
+            print("+ deserializing tensor.")
            # Deserialize data.
-            trimmed_data = pickle.loads(trimmed_tensor.cpu().numpy().tobytes())
-            result.append(trimmed_data)
+            result = pickle.loads(trimmed_tensor.cpu().numpy().tobytes())
        else:
-            result.append(trimmed_tensor)
+            print("+ tensor not deserialized")
+            result = trimmed_tensor

    return result
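The hunks above implement a variable-length all_gather: non-tensor payloads are pickled to bytes, ranks first exchange byte counts, pad to the longest payload (since `dist.all_gather` requires equal-sized tensors), then gather, trim, and unpickle. Below is a self-contained sketch of the same pattern, assuming an already-initialized process group; unlike the committed version, which keeps only the last rank's payload in `result`, this sketch returns one deserialized object per rank.

```python
# Sketch of the variable-length all_gather pattern used above.
# Assumes dist.init_process_group(...) has already been called.
import pickle

import torch
import torch.distributed as dist


def all_gather_pickled(obj, device: str = "cpu") -> list:
    # Serialize the payload and view the bytes as a uint8 tensor.
    payload = torch.frombuffer(
        bytearray(pickle.dumps(obj)), dtype=torch.uint8
    ).to(device)

    # Exchange byte counts so every rank can size its receive buffers.
    world_size = dist.get_world_size()
    local_size = torch.tensor([payload.numel()], device=device)
    sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(sizes, local_size)

    # Pad to the largest payload: all_gather needs equal-sized tensors.
    max_size = int(max(s.item() for s in sizes))
    if payload.numel() < max_size:
        pad = torch.zeros(max_size - payload.numel(), dtype=torch.uint8, device=device)
        payload = torch.cat((payload, pad))

    # Gather all padded buffers, then trim each to its true length and unpickle.
    buffers = [
        torch.zeros(max_size, dtype=torch.uint8, device=device)
        for _ in range(world_size)
    ]
    dist.all_gather(buffers, payload)
    return [
        pickle.loads(buf[: int(s.item())].cpu().numpy().tobytes())
        for buf, s in zip(buffers, sizes)
    ]
```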

@@ -687,10 +692,11 @@ def validate_hypergrid(
    env,
    gflownet,
    n_validation_samples,
-    visited_terminating_states: torch.Tensor,
+    visited_terminating_states: torch.Tensor | None,
    discovered_modes,
):
-    #validation_info = validate( # Standard validation shared across envs.
+    # Standard validation shared across envs.
+    #validation_info, visited_terminating_states = validate(
    # env,
    # gflownet,
    # n_validation_samples,
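One portability note on the widened annotation above: the `torch.Tensor | None` union (PEP 604) is only valid at runtime on Python 3.10+; on older interpreters the same signature can be written with `typing.Optional` or by postponing annotation evaluation. A minimal illustration (`f` is a hypothetical function, not project code):

```python
# PEP 604 unions in annotations need Python 3.10+, unless evaluation is postponed.
from __future__ import annotations  # makes "torch.Tensor | None" legal on 3.7+

from typing import Optional

import torch


def f(x: torch.Tensor | None) -> Optional[torch.Tensor]:  # equivalent spellings
    return x
```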
