Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Jul 18, 2024
1 parent fe9a248 commit c97e884
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 22 deletions.
7 changes: 4 additions & 3 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""

def __init__(
self,
env: Env,
objects_type: Literal["transitions", "trajectories", "states"],
capacity: int = 1000,
cutoff_distance: float = 0.,
p_norm_distance: float = 1.,
cutoff_distance: float = 0.0,
p_norm_distance: float = 1.0,
):
"""Instantiates a prioritized replay buffer.
Args:
Expand All @@ -137,7 +138,7 @@ def __init__(
norms are >= 0).
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""
"""
super().__init__(env, objects_type, capacity)
self.cutoff_distance = cutoff_distance
self.p_norm_distance = p_norm_distance
Expand Down
53 changes: 34 additions & 19 deletions tutorials/examples/train_hypergrid_multinode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from argparse import ArgumentParser

# didnt help.
#import torch.multiprocessing
#torch.multiprocessing.set_sharing_strategy('file_system')
# import torch.multiprocessing
# torch.multiprocessing.set_sharing_strategy('file_system')

import torch
import time
Expand All @@ -42,8 +42,9 @@
DEFAULT_SEED = 4444


def dist_init(dist_backend : str = "ccl"):
def dist_init(dist_backend: str = "ccl"):
import os

global my_rank
global my_size
print("PMI_SIZE={}".format(int(os.environ.get("PMI_SIZE", "0"))))
Expand All @@ -65,14 +66,13 @@ def dist_init(dist_backend : str = "ccl"):
try:
import torch_mpi
except ImportError as e:
raise Exception ("import torch_mpi failed, {}".format(e))
raise Exception("import torch_mpi failed, {}".format(e))
else:
raise Exception(f"invalid backend requested: {dist_backend}")


os.environ["RANK"] = os.environ.get("PMI_RANK", "0")
os.environ["WORLD_SIZE"] = os.environ.get("PMI_SIZE", "1")
print("+ OMP_NUM_THREADS = ", os.getenv('OMP_NUM_THREADS'))
print("+ OMP_NUM_THREADS = ", os.getenv("OMP_NUM_THREADS"))
dist.init_process_group(
backend=dist_backend,
init_method="env://",
Expand All @@ -94,6 +94,7 @@ def main(args): # noqa: C901
use_wandb = len(args.wandb_project) > 0
if use_wandb:
import wandb

wandb.init(project=args.wandb_project)
wandb.config.update(args)

Expand Down Expand Up @@ -289,8 +290,8 @@ def main(args): # noqa: C901
loss_backward_time = 0
opt_time = 0
rest_time = 0
print ("n_iterations = ", n_iterations)
print ("my_batch_size = ", my_batch_size)
print("n_iterations = ", n_iterations)
print("my_batch_size = ", my_batch_size)
time_start = time.time()
discovered_modes = set()

Expand All @@ -300,11 +301,11 @@ def main(args): # noqa: C901
env, n_samples=my_batch_size, sample_off_policy=off_policy_sampling
)
sample_end = time.time()
sample_time += (sample_end - sample_start)
sample_time += sample_end - sample_start
to_train_samples_start = time.time()
training_samples = gflownet.to_training_samples(trajectories)
to_train_samples_end = time.time()
to_train_samples_time += (to_train_samples_end - to_train_samples_start)
to_train_samples_time += to_train_samples_end - to_train_samples_start
if replay_buffer is not None:
with torch.no_grad():
replay_buffer.add(training_samples)
Expand All @@ -316,15 +317,15 @@ def main(args): # noqa: C901
loss_start = time.time()
loss = gflownet.loss(env, training_objects)
loss_end = time.time()
loss_time += (loss_end - loss_start)
loss_time += loss_end - loss_start
loss_backward_start = time.time()
loss.backward()
loss_backward_end = time.time()
loss_backward_time += (loss_backward_end - loss_backward_start)
loss_backward_time += loss_backward_end - loss_backward_start
opt_start = time.time()
optimizer.step()
opt_end = time.time()
opt_time += (opt_end - opt_start)
opt_time += opt_end - opt_start

visited_terminating_states.extend(trajectories.last_states)

Expand All @@ -334,7 +335,9 @@ def main(args): # noqa: C901
to_log = {"loss": loss.item(), "states_visited": states_visited}
if use_wandb:
wandb.log(to_log, step=iteration)
if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1):
if (iteration % args.validation_interval == 0) or (
iteration == n_iterations - 1
):
validation_info, discovered_modes = validate_hypergrid(
env,
gflownet,
Expand All @@ -349,11 +352,23 @@ def main(args): # noqa: C901

time_end = time.time()
total_time = time_end - time_start
rest_time = total_time - (sample_time + to_train_samples_time + loss_time + loss_backward_time + opt_time)
rest_time = total_time - (
sample_time + to_train_samples_time + loss_time + loss_backward_time + opt_time
)
dist.barrier()
if (my_rank == 0):
print ("total_time, sample_time, to_train_samples_time, loss_time, loss_backward_time, opt_time, rest_time")
print (total_time, sample_time, to_train_samples_time, loss_time, loss_backward_time, opt_time, rest_time)
if my_rank == 0:
print(
"total_time, sample_time, to_train_samples_time, loss_time, loss_backward_time, opt_time, rest_time"
)
print(
total_time,
sample_time,
to_train_samples_time,
loss_time,
loss_backward_time,
opt_time,
rest_time,
)

try:
return validation_info["l1_dist"]
Expand All @@ -378,7 +393,7 @@ def validate_hypergrid(

# # Add the mode counting metric.
states, scale = visited_terminating_states.tensor, env.scale_factor
mode_reward_threshold = 1. # Assumes height >= 5. TODO - verify.
mode_reward_threshold = 1.0 # Assumes height >= 5. TODO - verify.

# # Modes will have a reward greater than 1.
modes = states[env.reward(states) >= mode_reward_threshold]
Expand Down

0 comments on commit c97e884

Please sign in to comment.