diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 59ab5b0..d233c23 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -163,7 +163,7 @@ def _error_checker(self): sys.exit(1) except Exception as e: - print(f'Process {self.rank}: Error in error checker: {str(e)}') + print('Process {}: Error in error checker: {}'.format(self.rank, e)) self.signal_error() break @@ -196,8 +196,8 @@ def _cleanup(self): def gather_distributed_data( - local_data: Union[List, torch.Tensor], world_size: int = None, rank: int = None -) -> List: + local_tensor: torch.Tensor, world_size: int = None, rank: int = None, verbose: bool = False, +) -> torch.Tensor: """ Gather data from all processes in a distributed setting. @@ -207,75 +207,92 @@ def gather_distributed_data( rank: Current process rank (optional, will get from env if None) Returns: - List containing gathered data from all processes + On rank 0: Concatenated tensor from all processes + On other ranks: None """ - print("syncing distributed data") + if verbose: + print("syncing distributed data") if world_size is None: world_size = dist.get_world_size() if rank is None: rank = dist.get_rank() - # Convert data to tensor if it's not already. - if not isinstance(local_data, torch.Tensor): - print("+ converting tensor data via serialization") - # Serialize complex data structures. - serialized_data = pickle.dumps(local_data) - local_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(serialized_data)) - else: - print("+ tensor not converted") - local_tensor = local_data - - # First gather sizes to allocate correct buffer sizes. - local_size = torch.tensor([local_tensor.numel()], device=local_tensor.device) + # First gather batch_sizes to allocate correct buffer sizes. + local_batch_size = torch.tensor([local_tensor.shape[0]], device=local_tensor.device, dtype=local_tensor.dtype) if rank == 0: - size_list = [ - torch.tensor([0], device=local_tensor.device) for _ in range(world_size) + # Assumes same dimensionality on all ranks! + batch_size_list = [ + torch.zeros((1, ), device=local_tensor.device, dtype=local_tensor.dtype) for _ in range(world_size) ] else: - size_list = None - print("+ all gather of local_size={} to size_list".format(local_size)) - dist.gather(local_size, gather_list=size_list, dst=0) - # 44 [0] Process 0: Caught error: Invalid function argument. Expected parameter `tensor` to be of type torch.Tensor.[0] + batch_size_list = None + + if verbose: + print("rank={}, batch_size_list={}".format(rank, batch_size_list)) + print("+ gather of local_batch_size={} to batch_size_list".format(local_batch_size)) + dist.gather(local_batch_size, gather_list=batch_size_list, dst=0) dist.barrier() # Add synchronization + # Pad local tensor to maximum size. + if verbose: + print("+ padding local tensor") + + if rank == 0: + max_batch_size = (max(bs for bs in batch_size_list)) + else: + max_batch_size = 0 + + state_size = local_tensor.shape[1] # assume states are 1-d, is true for this env. + + # Broadcast max_size to all processes for padding + max_batch_size_tensor = torch.tensor(max_batch_size, device=local_tensor.device) + dist.broadcast(max_batch_size_tensor, src=0) # Pad local tensor to maximum size. - print("+ padding local tensor") - max_size = max(size.item() for size in size_list) - if local_tensor.numel() < max_size: + if local_tensor.shape[0] < max_batch_size: padding = torch.zeros( - max_size - local_tensor.numel(), + (max_batch_size - local_tensor.shape[0], state_size), dtype=local_tensor.dtype, - device=local_tensor.device, + device=local_tensor.device ) - local_tensor = torch.cat((local_tensor, padding)) + local_tensor = torch.cat((local_tensor, padding), dim=0) - # Gather all tensors. - print("+ gathering all tensors from world_size={}".format(world_size)) + # Gather padded tensors. if rank == 0: tensor_list = [ - torch.zeros(max_size, dtype=local_tensor.dtype, device=local_tensor.device) + torch.zeros( + (max_batch_size, state_size), + dtype=local_tensor.dtype, + device=local_tensor.device, + ) for _ in range(world_size) ] else: tensor_list = None + + if verbose: + print("+ gathering all tensors from world_size={}".format(world_size)) + print("rank={}, tensor_list={}".format(rank, tensor_list)) dist.gather(local_tensor, gather_list=tensor_list, dst=0) dist.barrier() # Add synchronization - # Trim padding and deserialize if necessary. - result = [] - for tensor, size in zip(tensor_list, size_list): - trimmed_tensor = tensor[: size.item()] - if not isinstance(local_data, torch.Tensor): - print("+ deserializing tensor.") - # Deserialize data. - result = pickle.loads(trimmed_tensor.cpu().numpy().tobytes()) - else: - print("+ tensor not deserialized") - result = trimmed_tensor + # Only rank 0 processes the results + if rank == 0: + results = [] + for tensor, batch_size in zip(tensor_list, batch_size_list): + trimmed_tensor = tensor[:batch_size.item(), ...] + results.append(trimmed_tensor) - return result + if verbose: + print("distributed n_results={}".format(len(results))) + + for r in results: + print(" {}".format(r.shape)) + + return torch.cat(results, dim=0) # Concatenates along the batch dimension. + + return None # For all non-zero ranks. def main(args): # noqa: C901 @@ -528,7 +545,7 @@ def cleanup(): world_size, cleanup_callback=cleanup, ) - handler.start() + #handler.start() for iteration in trange(n_iterations): @@ -616,6 +633,24 @@ def cleanup(): ] ) + log_this_iter = ((iteration % args.validation_interval == 0) or iteration == n_iterations - 1) + + print("before distributed -- orig_shape={}".format(visited_terminating_states.tensor.shape)) + if args.distributed and log_this_iter: + try: + all_visited_terminating_states = gather_distributed_data( + visited_terminating_states.tensor + ) + except Exception as e: + print('Process {}: Caught error: {}'.format(my_rank, e)) + #handler.signal_error() + sys.exit(1) + else: + all_visited_terminating_states = visited_terminating_states.tensor + + if my_rank == 0: + print("after distributed -- gathered_shape={}, orig_shape={}".format(all_visited_terminating_states.shape, visited_terminating_states.tensor.shape)) + # If we are on the master node, calculate the validation metrics. if my_rank == 0: to_log = { @@ -628,24 +663,12 @@ def cleanup(): "opt_time": opt_time, "rest_time": rest_time, } + if use_wandb: wandb.log(to_log, step=iteration) - if (iteration % args.validation_interval == 0) or ( - iteration == n_iterations - 1 - ): - - if args.distributed: - try: - all_visited_terminating_states = gather_distributed_data( - visited_terminating_states.tensor - ) - except Exception as e: - print(f'Process {my_rank}: Caught error: {str(e)}') - handler.signal_error() - sys.exit(1) - else: - all_visited_terminating_states = visited_terminating_states.tensor + if log_this_iter: + print("logging thjs iteration!") validation_info, discovered_modes = validate_hypergrid( env, gflownet, @@ -725,6 +748,7 @@ def validate_hypergrid( modes_found = set([tuple(s.tolist()) for s in modes]) discovered_modes.update(modes_found) validation_info["n_modes_found"] = len(discovered_modes) + print(len(discovered_modes)) # Old way of counting modes -- potentially buggy - to be removed. # # Add the mode counting metric.