diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py
index 2cf9fcc6..a5205249 100644
--- a/src/gfn/containers/replay_buffer.py
+++ b/src/gfn/containers/replay_buffer.py
@@ -117,13 +117,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         p_norm_distance: p-norm distance value to pass to torch.cdist, for the
             determination of novel states.
     """
+
     def __init__(
         self,
         env: Env,
         objects_type: Literal["transitions", "trajectories", "states"],
         capacity: int = 1000,
-        cutoff_distance: float = 0.,
-        p_norm_distance: float = 1.,
+        cutoff_distance: float = 0.0,
+        p_norm_distance: float = 1.0,
     ):
         """Instantiates a prioritized replay buffer.
         Args:
@@ -137,7 +138,7 @@ def __init__(
                 norms are >= 0).
             p_norm_distance: p-norm distance value to pass to torch.cdist, for the
                 determination of novel states.
-        """
+        """
         super().__init__(env, objects_type, capacity)
         self.cutoff_distance = cutoff_distance
         self.p_norm_distance = p_norm_distance
diff --git a/tutorials/examples/train_hypergrid_multinode.py b/tutorials/examples/train_hypergrid_multinode.py
index 30aa8ba2..e204f9bf 100644
--- a/tutorials/examples/train_hypergrid_multinode.py
+++ b/tutorials/examples/train_hypergrid_multinode.py
@@ -14,8 +14,8 @@ from argparse import ArgumentParser
 
 
 # didnt help.
-#import torch.multiprocessing
-#torch.multiprocessing.set_sharing_strategy('file_system')
+# import torch.multiprocessing
+# torch.multiprocessing.set_sharing_strategy('file_system')
 
 import torch
 import time
@@ -42,8 +42,9 @@ DEFAULT_SEED = 4444
 
 
-def dist_init(dist_backend : str = "ccl"):
+def dist_init(dist_backend: str = "ccl"):
     import os
+
     global my_rank
     global my_size
     print("PMI_SIZE={}".format(int(os.environ.get("PMI_SIZE", "0"))))
@@ -65,14 +66,13 @@ def dist_init(dist_backend : str = "ccl"):
         try:
             import torch_mpi
         except ImportError as e:
-            raise Exception ("import torch_mpi failed, {}".format(e))
+            raise Exception("import torch_mpi failed, {}".format(e))
     else:
         raise Exception(f"invalid backend requested: {dist_backend}")
 
-
     os.environ["RANK"] = os.environ.get("PMI_RANK", "0")
     os.environ["WORLD_SIZE"] = os.environ.get("PMI_SIZE", "1")
-    print("+ OMP_NUM_THREADS = ", os.getenv('OMP_NUM_THREADS'))
+    print("+ OMP_NUM_THREADS = ", os.getenv("OMP_NUM_THREADS"))
     dist.init_process_group(
         backend=dist_backend,
         init_method="env://",
@@ -94,6 +94,7 @@ def main(args): # noqa: C901
     use_wandb = len(args.wandb_project) > 0
     if use_wandb:
         import wandb
+
         wandb.init(project=args.wandb_project)
         wandb.config.update(args)
 
@@ -289,8 +290,8 @@ def main(args): # noqa: C901
     loss_backward_time = 0
     opt_time = 0
     rest_time = 0
-    print ("n_iterations = ", n_iterations)
-    print ("my_batch_size = ", my_batch_size)
+    print("n_iterations = ", n_iterations)
+    print("my_batch_size = ", my_batch_size)
     time_start = time.time()
 
     discovered_modes = set()
@@ -300,11 +301,11 @@ def main(args): # noqa: C901
             env, n_samples=my_batch_size, sample_off_policy=off_policy_sampling
         )
         sample_end = time.time()
-        sample_time += (sample_end - sample_start)
+        sample_time += sample_end - sample_start
         to_train_samples_start = time.time()
         training_samples = gflownet.to_training_samples(trajectories)
         to_train_samples_end = time.time()
-        to_train_samples_time += (to_train_samples_end - to_train_samples_start)
+        to_train_samples_time += to_train_samples_end - to_train_samples_start
         if replay_buffer is not None:
             with torch.no_grad():
                 replay_buffer.add(training_samples)
@@ -316,15 +317,15 @@ def main(args): # noqa: C901
         loss_start = time.time()
         loss = gflownet.loss(env, training_objects)
         loss_end = time.time()
-        loss_time += (loss_end - loss_start)
+        loss_time += loss_end - loss_start
         loss_backward_start = time.time()
         loss.backward()
         loss_backward_end = time.time()
-        loss_backward_time += (loss_backward_end - loss_backward_start)
+        loss_backward_time += loss_backward_end - loss_backward_start
         opt_start = time.time()
         optimizer.step()
         opt_end = time.time()
-        opt_time += (opt_end - opt_start)
+        opt_time += opt_end - opt_start
 
         visited_terminating_states.extend(trajectories.last_states)
 
@@ -334,7 +335,9 @@ def main(args): # noqa: C901
         to_log = {"loss": loss.item(), "states_visited": states_visited}
         if use_wandb:
             wandb.log(to_log, step=iteration)
-        if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1):
+        if (iteration % args.validation_interval == 0) or (
+            iteration == n_iterations - 1
+        ):
             validation_info, discovered_modes = validate_hypergrid(
                 env,
                 gflownet,
@@ -349,11 +352,23 @@ def main(args): # noqa: C901
 
     time_end = time.time()
     total_time = time_end - time_start
-    rest_time = total_time - (sample_time + to_train_samples_time + loss_time + loss_backward_time + opt_time)
+    rest_time = total_time - (
+        sample_time + to_train_samples_time + loss_time + loss_backward_time + opt_time
+    )
     dist.barrier()
-    if (my_rank == 0):
-        print ("total_time, sample_time, to_train_samples_time, loss_time, loss_backward_time, opt_time, rest_time")
-        print (total_time, sample_time, to_train_samples_time, loss_time, loss_backward_time, opt_time, rest_time)
+    if my_rank == 0:
+        print(
+            "total_time, sample_time, to_train_samples_time, loss_time, loss_backward_time, opt_time, rest_time"
+        )
+        print(
+            total_time,
+            sample_time,
+            to_train_samples_time,
+            loss_time,
+            loss_backward_time,
+            opt_time,
+            rest_time,
+        )
 
     try:
         return validation_info["l1_dist"]
@@ -378,7 +393,7 @@ def validate_hypergrid(
 
     # # Add the mode counting metric.
     states, scale = visited_terminating_states.tensor, env.scale_factor
-    mode_reward_threshold = 1.  # Assumes height >= 5. TODO - verify.
+    mode_reward_threshold = 1.0  # Assumes height >= 5. TODO - verify.
 
     # # Modes will have a reward greater than 1.
     modes = states[env.reward(states) >= mode_reward_threshold]
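
Note on the changed defaults (not part of the patch): rewriting cutoff_distance and p_norm_distance as 0.0 and 1.0 only changes the literal spelling, not behaviour. The sketch below illustrates how such a cutoff and p-norm are typically handed to torch.cdist to keep only novel states, as the docstring describes; the tensor names, shapes, and the exact filtering rule are illustrative assumptions, not the library's implementation.

import torch

buffer_states = torch.rand(100, 4)    # states already stored in the buffer (illustrative)
candidate_states = torch.rand(32, 4)  # freshly sampled states (illustrative)

cutoff_distance = 0.0  # same value as the new default, now spelled 0.0
p_norm_distance = 1.0  # p handed to torch.cdist (1.0 = Manhattan distance)

# Pairwise p-norm distances between every candidate and every stored state.
dists = torch.cdist(candidate_states, buffer_states, p=p_norm_distance)

# Keep a candidate only if it is farther than cutoff_distance from all stored
# states; with cutoff_distance == 0.0 this discards exact duplicates only.
is_novel = (dists > cutoff_distance).all(dim=1)
novel_states = candidate_states[is_novel]
print(f"{is_novel.sum().item()} of {candidate_states.shape[0]} candidates are novel")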