diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index 72d9746..065994e 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -12,14 +12,19 @@
 """
 from argparse import ArgumentParser
+from math import ceil
+from typing import Callable, List, Optional, Union
 import os
-
-import torch
+import pickle
+import signal
+import sys
+import threading
 import time
+
+from torch.nn.parallel import DistributedDataParallel as DDP
 from tqdm import tqdm, trange
-from math import ceil
+import torch
 import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
 
 from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer
 from gfn.gflownet import (
@@ -37,6 +42,7 @@
 from gfn.utils.training import validate
 
 from torch.profiler import profile, ProfilerActivity
 
+
 DEFAULT_SEED = 4444
@@ -50,8 +56,8 @@ def average_gradients(model):
 
 def initialize_distributed_compute(dist_backend: str = "ccl"):
     """Initalizes distributed compute using either ccl or mpi backends."""
-    #global my_rank # TODO: remove globals?
-    #global my_size # TODO: remove globals?
+    # global my_rank  # TODO: remove globals?
+    # global my_size  # TODO: remove globals?
 
     pmi_size = int(os.environ.get("PMI_SIZE", "0"))  # 0 or 1 default value?
     print("+ Initalizing distributed compute, PMI_SIZE={}".format(pmi_size))
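For reference, a minimal sketch of how this initializer is typically driven. It assumes an MPI-style launcher (e.g. `mpirun -n 4 python train_hypergrid.py --distributed`) that sets `PMI_SIZE`/`PMI_RANK`; the rendezvous variables below are an assumption of this sketch, not something this patch sets:

    # Hypothetical single-node smoke test for the ccl/mpi init path.
    # MASTER_ADDR / MASTER_PORT are the usual env:// rendezvous variables;
    # adjust to whatever your launcher actually provides.
    import os
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    my_rank, my_size = initialize_distributed_compute(dist_backend="ccl")
    print(f"+ Rank {my_rank}/{my_size} initialized.")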
@@ -96,9 +102,167 @@ def initialize_distributed_compute(dist_backend: str = "ccl"):
 
     return (my_rank, my_size)
 
 
+class DistributedErrorHandler:
+    def __init__(
+        self,
+        device_str: str,
+        rank: int,
+        world_size: int,
+        error_check_interval: float = 1.0,
+        cleanup_callback: Optional[Callable] = None,
+    ):
+        """
+        Initialize error handler for distributed training.
+
+        Args:
+            device_str: String representing the current device.
+            rank: Current process rank.
+            world_size: Total number of processes.
+            error_check_interval: How often to check for errors (in seconds).
+            cleanup_callback: Optional function to call before shutdown.
+        """
+        self.device_str = device_str
+        self.rank = rank
+        self.world_size = world_size
+        self.error_check_interval = error_check_interval
+        self.cleanup_callback = cleanup_callback
+        self.shutdown_flag = threading.Event()
+        self.error_tensor = torch.zeros(1, dtype=torch.uint8, device=self.device_str)
+
+        # Set up the error-checking thread.
+        self.checker_thread = threading.Thread(target=self._error_checker, daemon=True)
+
+        # Register signal handlers.
+        signal.signal(signal.SIGTERM, self._signal_handler)
+        signal.signal(signal.SIGINT, self._signal_handler)
+
+    def start(self):
+        """Start the error-checking thread."""
+        self.checker_thread.start()
+
+    def _signal_handler(self, signum, frame):
+        """Handle external signals."""
+        print(f"Process {self.rank} received signal {signum}")
+        self.shutdown_flag.set()
+        self._cleanup()
+        sys.exit(1)
+
+    def _error_checker(self):
+        """Periodically check for errors across all processes."""
+        while not self.shutdown_flag.is_set():
+            try:
+                # Use all_reduce to check if any process has errored.
+                error_count = torch.zeros_like(self.error_tensor)
+                dist.all_reduce(error_count, op=dist.ReduceOp.SUM)
+
+                if error_count.item() > 0:
+                    print(f"Process {self.rank}: Detected error in another process")
+                    self.shutdown_flag.set()
+                    self._cleanup()
+                    os._exit(1)  # sys.exit would only end this checker thread.
+
+            except Exception as e:
+                print(f"Process {self.rank}: Error in error checker: {str(e)}")
+                self.signal_error()
+                break
+
+            time.sleep(self.error_check_interval)
+
+    def signal_error(self):
+        """Signal that this process has encountered an error."""
+        try:
+            self.error_tensor.fill_(1)
+            dist.all_reduce(self.error_tensor, op=dist.ReduceOp.SUM)
+        except Exception:
+            pass  # If this fails, processes will eventually time out.
+
+        self.shutdown_flag.set()
+        self._cleanup()
+        sys.exit(1)
+
+    def _cleanup(self):
+        """Perform cleanup before shutdown."""
+        if self.cleanup_callback:
+            try:
+                self.cleanup_callback()
+            except Exception as e:
+                print(f"Process {self.rank}: Error in cleanup: {str(e)}")
+
+        try:
+            dist.destroy_process_group()
+        except Exception:
+            pass
+
+
+def gather_distributed_data(
+    local_data: Union[List, torch.Tensor],
+    world_size: Optional[int] = None,
+    rank: Optional[int] = None,
+) -> List:
+    """
+    Gather data from all processes in a distributed setting.
+
+    Args:
+        local_data: Data from the current process (List or Tensor).
+        world_size: Number of processes (queried from the process group if None).
+        rank: Current process rank (queried from the process group if None).
+
+    Returns:
+        List containing the gathered data from all processes, ordered by rank.
+    """
+    print("+ Syncing distributed data.")
+
+    if world_size is None:
+        world_size = dist.get_world_size()
+    if rank is None:
+        rank = dist.get_rank()
+
+    # Convert data to a tensor if it's not one already.
+    if not isinstance(local_data, torch.Tensor):
+        # Serialize complex data structures.
+        serialized_data = pickle.dumps(local_data)
+        local_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(serialized_data))
+    else:
+        local_tensor = local_data
+
+    # First gather sizes to allocate correctly sized buffers.
+    local_size = torch.tensor([local_tensor.numel()], device=local_tensor.device)
+    size_list = [
+        torch.tensor([0], device=local_tensor.device) for _ in range(world_size)
+    ]
+    dist.all_gather(size_list, local_size)
+
+    # Pad the local tensor to the maximum size.
+    max_size = max(size.item() for size in size_list)
+    if local_tensor.numel() < max_size:
+        padding = torch.zeros(
+            max_size - local_tensor.numel(),
+            dtype=local_tensor.dtype,
+            device=local_tensor.device,
+        )
+        local_tensor = torch.cat((local_tensor, padding))
+
+    # Gather all tensors.
+    tensor_list = [
+        torch.zeros(max_size, dtype=local_tensor.dtype, device=local_tensor.device)
+        for _ in range(world_size)
+    ]
+    dist.all_gather(tensor_list, local_tensor)
+
+    # Trim the padding and deserialize if necessary.
+    result = []
+    for tensor, size in zip(tensor_list, size_list):
+        trimmed_tensor = tensor[: size.item()]
+        if not isinstance(local_data, torch.Tensor):
+            # Deserialize data.
+            trimmed_data = pickle.loads(trimmed_tensor.cpu().numpy().tobytes())
+            result.append(trimmed_data)
+        else:
+            result.append(trimmed_tensor)
+
+    return result
+
+
 def main(args):  # noqa: C901
     seed = args.seed if args.seed != 0 else DEFAULT_SEED
-    set_seed(seed)
 
     device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
 
     use_wandb = args.wandb_project != ""
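A minimal round-trip sketch of `gather_distributed_data` for non-tensor payloads. It assumes `my_rank` comes from `initialize_distributed_compute()`; pickled payloads travel as CPU byte tensors, so this path presumes a backend that supports CPU collectives (e.g. gloo or ccl):

    # Runs identically on every rank. Each rank contributes a different
    # Python list; pickle -> byte tensor -> all_gather -> unpickle gives
    # every rank the full set of payloads, ordered by rank.
    payload = [("rank", my_rank), ("n_visited", 128)]
    gathered = gather_distributed_data(payload)
    assert len(gathered) == dist.get_world_size()
    assert gathered[my_rank] == payload  # all_gather preserves rank order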
@@ -114,13 +278,15 @@
 
     if args.distributed:
         my_rank, my_size = initialize_distributed_compute()
-        rank = dist.get_rank()
+        my_rank = dist.get_rank()
         world_size = torch.distributed.get_world_size()
-        print(f"Running with DDP on rank {rank}/{world_size}.")
+        print(f"Running with DDP on rank {my_rank}/{world_size}.")
     else:
         world_size = 1  # Single machine.
         my_rank = 0  # Single machine.
 
+    set_seed(seed + my_rank)
+
     # 1. Create the environment
     env = HyperGrid(
         args.ndim,
@@ -323,12 +489,30 @@
 
     if args.profile:
         keep_active = args.trajectories_to_profile // args.batch_size
         prof = profile(
-            schedule=torch.profiler.schedule(wait=1, warmup=1, active=keep_active, repeat=1),
+            schedule=torch.profiler.schedule(
+                wait=1, warmup=1, active=keep_active, repeat=1
+            ),
             activities=[ProfilerActivity.CPU],
             record_shapes=True,
-            with_stack=True
-        )
+            with_stack=True,
+        )
         prof.start()
+
+    if args.distributed:
+        # Create and start the error handler, reusing the rank and world
+        # size computed during distributed initialization above.
+        def cleanup():
+            print(f"Process {my_rank}: Cleaning up...")
+
+        handler = DistributedErrorHandler(
+            device_str,
+            my_rank,
+            world_size,
+            cleanup_callback=cleanup,
+        )
+        handler.start()
+
     for iteration in trange(n_iterations):
         iteration_start = time.time()
@@ -385,7 +569,6 @@
 
         if args.distributed:
             loss = loss / (per_node_batch_size)
-
         # Time backpropagation computation.
         loss_backward_start = time.time()
         loss.backward()
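One point worth spelling out about `DistributedErrorHandler`: healthy ranks periodically enter `all_reduce` with a zero tensor from the checker thread, while a failing rank enters the same collective with a 1 via `signal_error`. Collectives pair up by call order, so every rank observes a nonzero sum. A conceptual sketch (the `crashed` flag is illustrative; note that collectives issued from a side thread must not interleave with the training loop's own collectives, which is why the check runs on a fixed interval):

    # Every rank must reach this collective for the sum to propagate.
    flag = torch.zeros(1, dtype=torch.uint8)
    if crashed:  # hypothetical local failure condition
        flag.fill_(1)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    if flag.item() > 0:  # some rank failed; shut down everywhere
        dist.destroy_process_group()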
@@ -416,7 +599,7 @@
                 ]
             )
 
-        # If we are on the master node.
+        # If we are on the master node, calculate the validation metrics.
         if my_rank == 0:
             to_log = {
                 "loss": loss.item(),
@@ -433,15 +616,30 @@
             if (iteration % args.validation_interval == 0) or (
                 iteration == n_iterations - 1
             ):
+
+                if args.distributed:
+                    try:
+                        # Gather the flattened state tensors from every rank,
+                        # then restore the (N, ndim) shape after concatenation
+                        # (the gather path pads/trims 1-D byte streams only).
+                        gathered = gather_distributed_data(
+                            visited_terminating_states.tensor.flatten()
+                        )
+                        all_visited_terminating_states = torch.cat(gathered).view(
+                            -1, args.ndim
+                        )
+                    except Exception as e:
+                        print(f"Process {my_rank}: Caught error: {str(e)}")
+                        handler.signal_error()
+                        sys.exit(1)
+                else:
+                    all_visited_terminating_states = visited_terminating_states.tensor
+
                 validation_info, discovered_modes = validate_hypergrid(
                     env,
                     gflownet,
                     args.validation_samples,
-                    visited_terminating_states,
+                    all_visited_terminating_states,
                     discovered_modes,
                 )
+
                 if use_wandb:
                     wandb.log(validation_info, step=iteration)
+
                 to_log.update(validation_info)
                 tqdm.write(f"{iteration}: {to_log}")
 
@@ -471,8 +669,8 @@
     }
 
     print("+ Final timing.")
-    for k, v in to_log.iteritems():
-        print(" {k}: {.:6f}".format(k, v))
+    for k, v in to_log.items():
+        print("    {}: {:.6f}".format(k, v))
 
     if args.profile:
         prof.stop()
@@ -489,21 +687,22 @@
 def validate_hypergrid(
     env,
     gflownet,
     n_validation_samples,
-    visited_terminating_states,
+    visited_terminating_states: torch.Tensor,
     discovered_modes,
 ):
-    validation_info = validate(  # Standard validation shared across envs.
-        env,
-        gflownet,
-        n_validation_samples,
-        visited_terminating_states,
-    )
-
-    # # Add the mode counting metric.
-    states, scale = visited_terminating_states.tensor, env.scale_factor
+    # validation_info = validate(  # Standard validation shared across envs.
+    #     env,
+    #     gflownet,
+    #     n_validation_samples,
+    #     visited_terminating_states,
+    # )
+    validation_info = {}
+
+    # Add the mode counting metric.
+    states, scale = visited_terminating_states, env.scale_factor
     mode_reward_threshold = 1.0  # Assumes height >= 5. TODO - verify.
 
-    # # Modes will have a reward greater than 1.
+    # Modes will have a reward greater than 1.
     modes = states[env.reward(states) >= mode_reward_threshold]
     modes_found = set([tuple(s.tolist()) for s in modes])
     discovered_modes.update(modes_found)
@@ -709,8 +908,8 @@
         "--trajectories_to_profile",
         type=int,
         default=2048,
-        help="Number of trajectories to profile using the Pytorch Profiler." +
-        " Preferably, a multiple of batch size.",
+        help="Number of trajectories to profile using the Pytorch Profiler."
+        + " Preferably, a multiple of batch size.",
     )
 
     args = parser.parse_args()
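For completeness, the mode-counting metric introduced in `validate_hypergrid` reduces to a few lines. The sketch below restates it standalone, assuming `states` is the gathered `(N, ndim)` tensor of terminating states and keeping the patch's reward threshold of 1.0 (which presumes `height >= 5`):

    # A state counts toward a mode iff its reward clears the threshold;
    # deduplicating through the `discovered_modes` set makes the metric
    # monotone across validation rounds.
    modes = states[env.reward(states) >= 1.0]
    discovered_modes.update(tuple(s.tolist()) for s in modes)
    n_modes_found = len(discovered_modes)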