Commit

major changes for better (?) distributed metric tracking -- still needs to be debugged
josephdviviano committed Dec 12, 2024
2 parents bdc36d4 + 089bb70 commit a7355f2
Showing 1 changed file with 229 additions and 30 deletions.
259 changes: 229 additions & 30 deletions tutorials/examples/train_hypergrid.py
@@ -12,14 +12,19 @@
"""

from argparse import ArgumentParser
from math import ceil
from typing import List, Any, Union, Optional, Callable
import os

import torch
import pickle
import signal
import sys
import threading
import time

from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm, trange
from math import ceil
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer
from gfn.gflownet import (
@@ -37,6 +42,7 @@
from gfn.utils.training import validate
from torch.profiler import profile, ProfilerActivity


DEFAULT_SEED = 4444


@@ -50,8 +56,8 @@ def average_gradients(model):

def initialize_distributed_compute(dist_backend: str = "ccl"):
"""Initalizes distributed compute using either ccl or mpi backends."""
#global my_rank # TODO: remove globals?
#global my_size # TODO: remove globals?
# global my_rank # TODO: remove globals?
# global my_size # TODO: remove globals?

pmi_size = int(os.environ.get("PMI_SIZE", "0")) # 0 or 1 default value?
print("+ Initalizing distributed compute, PMI_SIZE={}".format(pmi_size))
@@ -96,9 +102,167 @@ def initialize_distributed_compute(dist_backend: str = "ccl"):
return (my_rank, my_size)


class DistributedErrorHandler:
    def __init__(
        self,
        device_str: str,
        rank: int,
        world_size: int,
        error_check_interval: float = 1.0,
        cleanup_callback: Optional[Callable] = None,
    ):
        """Initialize the error handler for distributed training.

        Args:
            device_str: String representing the current device.
            rank: Current process rank.
            world_size: Total number of processes.
            error_check_interval: How often to check for errors (in seconds).
            cleanup_callback: Optional function to call before shutdown.
        """
self.device_str = device_str
self.rank = rank
self.world_size = world_size
self.error_check_interval = error_check_interval
self.cleanup_callback = cleanup_callback
self.shutdown_flag = threading.Event()
self.error_tensor = torch.zeros(1, dtype=torch.uint8, device=self.device_str)

# Set up error checking thread
self.checker_thread = threading.Thread(target=self._error_checker, daemon=True)

# Register signal handlers
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)

def start(self):
"""Start error checking thread"""
self.checker_thread.start()

def _signal_handler(self, signum, frame):
"""Handle external signals"""
print(f'Process {self.rank} received signal {signum}')
self.shutdown_flag.set()
self._cleanup()
sys.exit(1)

def _error_checker(self):
"""Periodically check for errors across all processes"""
while not self.shutdown_flag.is_set():
try:
# Use all_reduce to check if any process has errored
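                # Healthy ranks contribute zeros here, while a rank that has called
                # signal_error() joins the same collective with ones, so the summed
                # result is nonzero whenever any rank has failed (assuming the
                # collective calls line up across ranks).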
error_count = torch.zeros_like(self.error_tensor)
dist.all_reduce(error_count, op=dist.ReduceOp.SUM)

if error_count.item() > 0:
print(f'Process {self.rank}: Detected error in another process')
self.shutdown_flag.set()
self._cleanup()
sys.exit(1)

except Exception as e:
print(f'Process {self.rank}: Error in error checker: {str(e)}')
self.signal_error()
break

time.sleep(self.error_check_interval)

def signal_error(self):
"""Signal that this process has encountered an error"""
try:
self.error_tensor.fill_(1)
dist.all_reduce(self.error_tensor, op=dist.ReduceOp.SUM)
        except Exception:
            pass  # If this fails, processes will eventually time out.

self.shutdown_flag.set()
self._cleanup()
sys.exit(1)

def _cleanup(self):
"""Perform cleanup before shutdown"""
if self.cleanup_callback:
try:
self.cleanup_callback()
except Exception as e:
print(f'Process {self.rank}: Error in cleanup: {str(e)}')

try:
dist.destroy_process_group()
        except Exception:
            pass
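
# A minimal usage sketch for this handler (illustrative only, not part of this
# commit). `run_with_error_handler` and `train_fn` are hypothetical names, and
# the default process group is assumed to be initialized already.
def run_with_error_handler(device_str: str, train_fn: Callable) -> None:
    handler = DistributedErrorHandler(
        device_str, dist.get_rank(), dist.get_world_size()
    )
    handler.start()  # Launch the background error-checking thread.
    try:
        train_fn()
    except Exception as exc:
        print(f"Process {dist.get_rank()}: caught {exc}, signalling other ranks")
        handler.signal_error()  # Sets the shared error flag and exits this process.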


def gather_distributed_data(
    local_data: Union[List, torch.Tensor],
    world_size: Optional[int] = None,
    rank: Optional[int] = None,
) -> List:
    """Gather data from all processes in a distributed setting.

    Args:
        local_data: Data from the current process (List or Tensor).
        world_size: Number of processes (optional, read from the process group if None).
        rank: Current process rank (optional, read from the process group if None).

    Returns:
        List containing the gathered data from all processes.
    """
print("syncing distributed data")

if world_size is None:
world_size = dist.get_world_size()
if rank is None:
rank = dist.get_rank()

# Convert data to tensor if it's not already.
if not isinstance(local_data, torch.Tensor):
# Serialize complex data structures.
serialized_data = pickle.dumps(local_data)
local_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(serialized_data))
else:
local_tensor = local_data

# First gather sizes to allocate correct buffer sizes.
local_size = torch.tensor([local_tensor.numel()], device=local_tensor.device)
size_list = [
torch.tensor([0], device=local_tensor.device) for _ in range(world_size)
]
dist.all_gather(size_list, local_size)

# Pad local tensor to maximum size.
max_size = max(size.item() for size in size_list)
if local_tensor.numel() < max_size:
padding = torch.zeros(
max_size - local_tensor.numel(),
dtype=local_tensor.dtype,
device=local_tensor.device,
)
local_tensor = torch.cat((local_tensor, padding))

# Gather all tensors.
tensor_list = [
torch.zeros(max_size, dtype=local_tensor.dtype, device=local_tensor.device)
for _ in range(world_size)
]
dist.all_gather(tensor_list, local_tensor)

# Trim padding and deserialize if necessary.
result = []
    for tensor, size in zip(tensor_list, size_list):
        trimmed_tensor = tensor[: size.item()]
        if not isinstance(local_data, torch.Tensor):
            # Deserialize the data.
            trimmed_data = pickle.loads(trimmed_tensor.cpu().numpy().tobytes())
            result.append(trimmed_data)
        else:
            result.append(trimmed_tensor)

return result
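
# A hedged sketch of how this helper might be called (illustrative only, not part
# of this commit; `_gather_example` and the payload below are made up). Assumes
# the process group is already initialized.
def _gather_example() -> None:
    local_stats = {"rank": dist.get_rank(), "states_visited": 123}
    gathered = gather_distributed_data(local_stats)
    # Every rank receives a list with one entry per rank, in rank order, e.g.
    # [{"rank": 0, "states_visited": 123}, {"rank": 1, "states_visited": 456}].
    assert len(gathered) == dist.get_world_size()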


def main(args): # noqa: C901
seed = args.seed if args.seed != 0 else DEFAULT_SEED
set_seed(seed)
device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

use_wandb = args.wandb_project != ""
@@ -114,13 +278,15 @@ def main(args):  # noqa: C901

if args.distributed:
my_rank, my_size = initialize_distributed_compute()
rank = dist.get_rank()
my_rank = dist.get_rank()
world_size = torch.distributed.get_world_size()
print(f"Running with DDP on rank {rank}/{world_size}.")
print(f"Running with DDP on rank {my_rank}/{world_size}.")
else:
world_size = 1 # Single machine.
my_rank = 0 # Single machine.

set_seed(seed + my_rank)
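    # NOTE: offsetting the seed by rank gives each worker its own sampling stream.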

# 1. Create the environment
env = HyperGrid(
args.ndim,
@@ -323,12 +489,30 @@ def main(args):  # noqa: C901
if args.profile:
keep_active = args.trajectories_to_profile // args.batch_size
prof = profile(
schedule=torch.profiler.schedule(wait=1, warmup=1, active=keep_active, repeat=1),
schedule=torch.profiler.schedule(
wait=1, warmup=1, active=keep_active, repeat=1
),
activities=[ProfilerActivity.CPU],
record_shapes=True,
with_stack=True
)
with_stack=True,
)
prof.start()

    if args.distributed:
        # Create and start the error handler, reusing the rank and world size
        # obtained during initialization.
        def cleanup():
            print(f"Process {my_rank}: Cleaning up...")

        handler = DistributedErrorHandler(
            device_str,
            my_rank,
            world_size,
            cleanup_callback=cleanup,
        )
        handler.start()

for iteration in trange(n_iterations):

iteration_start = time.time()
@@ -385,7 +569,6 @@ def main(args):  # noqa: C901
if args.distributed:
loss = loss / (per_node_batch_size)


# Time backpropagation computation.
loss_backward_start = time.time()
loss.backward()
@@ -416,7 +599,7 @@ def main(args):  # noqa: C901
]
)

# If we are on the master node.
# If we are on the master node, calculate the validation metrics.
if my_rank == 0:
to_log = {
"loss": loss.item(),
@@ -433,15 +616,30 @@ def main(args):  # noqa: C901
if (iteration % args.validation_interval == 0) or (
iteration == n_iterations - 1
):

if args.distributed:
try:
all_visited_terminating_states = gather_distributed_data(
visited_terminating_states.tensor
)
except Exception as e:
print(f'Process {my_rank}: Caught error: {str(e)}')
handler.signal_error()
sys.exit(1)
else:
all_visited_terminating_states = visited_terminating_states.tensor

validation_info, discovered_modes = validate_hypergrid(
env,
gflownet,
args.validation_samples,
visited_terminating_states,
all_visited_terminating_states,
discovered_modes,
)

if use_wandb:
wandb.log(validation_info, step=iteration)

to_log.update(validation_info)
tqdm.write(f"{iteration}: {to_log}")

@@ -471,8 +669,8 @@ def main(args):  # noqa: C901
}

print("+ Final timing.")
for k, v in to_log.iteritems():
print(" {k}: {.:6f}".format(k, v))
for k, v in to_log.items():
print(" {}: {:.6f}".format(k, v))

if args.profile:
prof.stop()
@@ -489,21 +687,22 @@ def validate_hypergrid(
env,
gflownet,
n_validation_samples,
visited_terminating_states,
visited_terminating_states: torch.Tensor,
discovered_modes,
):
validation_info = validate( # Standard validation shared across envs.
env,
gflownet,
n_validation_samples,
visited_terminating_states,
)

# # Add the mode counting metric.
states, scale = visited_terminating_states.tensor, env.scale_factor
#validation_info = validate( # Standard validation shared across envs.
# env,
# gflownet,
# n_validation_samples,
# visited_terminating_states,
#)
validation_info = {}

# Add the mode counting metric.
states, scale = visited_terminating_states, env.scale_factor
mode_reward_threshold = 1.0 # Assumes height >= 5. TODO - verify.

# # Modes will have a reward greater than 1.
# Modes will have a reward greater than 1.
modes = states[env.reward(states) >= mode_reward_threshold]
modes_found = set([tuple(s.tolist()) for s in modes])
discovered_modes.update(modes_found)
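    # `discovered_modes` is passed in and mutated in place, so it accumulates every
    # distinct mode seen across validation calls.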
@@ -709,8 +908,8 @@ def validate_hypergrid(
"--trajectories_to_profile",
type=int,
default=2048,
help="Number of trajectories to profile using the Pytorch Profiler." +
" Preferably, a multiple of batch size.",
help="Number of trajectories to profile using the Pytorch Profiler."
+ " Preferably, a multiple of batch size.",
)

args = parser.parse_args()