synced logic of both scripts (they should be merged next)
josephdviviano committed Aug 2, 2024
1 parent 7f49488 commit 2786155
Showing 2 changed files with 71 additions and 22 deletions.
14 changes: 13 additions & 1 deletion tutorials/examples/train_hypergrid.py
@@ -235,12 +235,13 @@ def main(args): # noqa: C901
n_iterations = args.n_trajectories // args.batch_size
validation_info = {"l1_dist": float("inf")}
discovered_modes = set()
is_on_policy = args.replay_buffer_size == 0

for iteration in trange(n_iterations):
trajectories = gflownet.sample_trajectories(
env,
n_samples=args.batch_size,
save_logprobs=is_on_policy,
save_estimator_outputs=False,
)
training_samples = gflownet.to_training_samples(trajectories)
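The new flag centralizes the on-/off-policy distinction: log-probabilities are cached at sampling time only when no replay buffer is in use, since off-policy training must recompute them under the current policy anyway. A minimal sketch of the intended gating in the script's context (the replay_buffer.sample signature is an assumption, not shown in this diff):

    is_on_policy = args.replay_buffer_size == 0
    trajectories = gflownet.sample_trajectories(
        env,
        n_samples=args.batch_size,
        save_logprobs=is_on_policy,  # cache log-probs only for on-policy training
        save_estimator_outputs=False,
    )
    training_samples = gflownet.to_training_samples(trajectories)
    if not is_on_policy:
        replay_buffer.add(training_samples)  # assumed buffer API
        training_samples = replay_buffer.sample(n_trajectories=args.batch_size)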
@@ -445,6 +446,17 @@ def validate_hypergrid(
help="Name of the wandb project. If empty, don't use wandb",
)

parser.add_argument(
"--calculate_all_states",
action="store_true",
help="Enumerates all states.",
)
parser.add_argument(
"--calculate_partition",
action="store_true",
help="Calculates the true partition function.",
)

args = parser.parse_args()

print(main(args))
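Both new flags are plain store_true switches, so they default to False and the exponential-cost enumeration and partition computations stay off unless explicitly requested. A quick self-contained check of that behavior (hypothetical, outside the script):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--calculate_all_states", action="store_true")
    parser.add_argument("--calculate_partition", action="store_true")

    assert parser.parse_args([]).calculate_all_states is False
    assert parser.parse_args(["--calculate_partition"]).calculate_partition is True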
79 changes: 58 additions & 21 deletions tutorials/examples/train_hypergrid_multinode.py
@@ -13,18 +13,14 @@

from argparse import ArgumentParser

import torch
import time
from tqdm import tqdm, trange
from math import ceil
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer
from gfn.gflownet import (
DBGFlowNet,
FMGFlowNet,
@@ -54,7 +50,8 @@ def dist_init(dist_backend: str = "ccl"):
if dist_backend == "ccl":
print("+ CCL backend requested...")
try:
# Note - intel must be imported before oneccl!
import oneccl_bindings_for_pytorch # noqa: F401
except ImportError as e:
raise Exception(
"import oneccl_bindings_for_pytorch failed, {}".format(e)
@@ -64,7 +61,7 @@ def dist_init(dist_backend: str = "ccl"):
print("+ MPI backend requested...")
assert torch.distributed.is_mpi_available()
try:
import torch_mpi # noqa: F401
except ImportError as e:
raise Exception("import torch_mpi failed, {}".format(e))
else:
@@ -88,7 +85,6 @@ def dist_init(dist_backend: str = "ccl"):
def main(args): # noqa: C901
seed = args.seed if args.seed != 0 else DEFAULT_SEED
set_seed(seed)
device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

use_wandb = len(args.wandb_project) > 0
@@ -98,14 +94,22 @@ def main(args): # noqa: C901
wandb.init(project=args.wandb_project)
wandb.config.update(args)

# Initialize distributed computation.
dist_init()
rank = dist.get_rank()
world_size = torch.distributed.get_world_size()
print(f"Running with DDP on rank {rank}/{world_size}.")

# 1. Create the environment
env = HyperGrid(
args.ndim,
args.height,
args.R0,
args.R1,
args.R2,
device_str=device_str,
calculate_partition=args.calculate_partition,
calculate_all_states=args.calculate_all_states,
)
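Passing the two new flags through to the environment lets large grids skip enumerating all states and computing the exact partition function, both of which scale as height**ndim. For a sense of the cost being avoided (plain arithmetic, not from this diff):

    ndim, height = 4, 20
    n_states = height ** ndim  # 160_000 states enumerated when the flags are on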

# 2. Create the gflownets.
@@ -126,7 +130,10 @@ def main(args): # noqa: C901
hidden_dim=args.hidden_dim,
n_hidden_layers=args.n_hidden,
)

# Prepare the model for data parallel training.
module = DDP(module)

estimator = DiscretePolicyEstimator(
module=module,
n_actions=env.n_actions,
@@ -165,8 +172,10 @@ def main(args): # noqa: C901
pb_module is not None
), f"pb_module is None. Command-line arguments: {args}"

# Prepare the model for data parallel training.
pf_module = DDP(pf_module)
pb_module = DDP(pb_module)
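Wrapping each module in DDP makes every backward() all-reduce gradients across ranks, so the per-rank optimizers stay in lockstep with no explicit communication in the training loop. An illustrative toy (assumes a process group is already initialized):

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    model = DDP(torch.nn.Linear(4, 1))
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()  # gradients are averaged across all ranks here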

pf_estimator = DiscretePolicyEstimator(
module=pf_module,
n_actions=env.n_actions,
Expand All @@ -183,7 +192,6 @@ def main(args): # noqa: C901
gflownet = ModifiedDBGFlowNet(
pf_estimator,
pb_estimator,
)

if args.loss in ("DB", "SubTB"):
@@ -205,7 +213,9 @@ def main(args): # noqa: C901
n_hidden_layers=args.n_hidden,
torso=pf_module.torso if args.tied else None,
)

# TODO: make module also DDP?

logF_estimator = ScalarEstimator(
module=module, preprocessor=env.preprocessor
)
@@ -214,28 +224,24 @@ def main(args): # noqa: C901
pf=pf_estimator,
pb=pb_estimator,
logF=logF_estimator,
)
else:
gflownet = SubTBGFlowNet(
pf=pf_estimator,
pb=pb_estimator,
logF=logF_estimator,
weighting=args.subTB_weighting,
lamda=args.subTB_lambda,
)
elif args.loss == "TB":
gflownet = TBGFlowNet(
pf=pf_estimator,
pb=pb_estimator,
)
elif args.loss == "ZVar":
gflownet = LogPartitionVarianceGFlowNet(
pf=pf_estimator,
pb=pb_estimator,
)

assert gflownet is not None, f"No gflownet for loss {args.loss}"
@@ -251,12 +257,21 @@ def main(args): # noqa: C901
objects_type = "states"
else:
raise NotImplementedError(f"Unknown loss: {args.loss}")

if args.replay_buffer_prioritized:
replay_buffer = PrioritizedReplayBuffer(
env,
objects_type=objects_type,
capacity=args.replay_buffer_size,
p_norm_distance=1, # Use L1-norm for diversity estimation.
cutoff_distance=0, # -1 turns off diversity-based filtering.
)
else:
replay_buffer = ReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)

# 3. Create the optimizer
# Policy parameters have their own LR.
params = [
{
@@ -278,12 +293,17 @@ def main(args): # noqa: C901

optimizer = torch.optim.Adam(params)
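Adam parameter groups give the policy parameters their own learning rate, as the comment above notes. A self-contained illustration of the pattern (toy module and rates, not the script's values):

    import torch

    net = torch.nn.Linear(4, 1)
    logZ = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([
        {"params": net.parameters(), "lr": 1e-3},  # policy parameters
        {"params": [logZ], "lr": 1e-1},  # logZ-style scalars often use a larger LR
    ])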

print("-A-")
visited_terminating_states = env.states_from_batch_shape((0,))

states_visited = 0
n_iterations = ceil(args.n_trajectories / args.batch_size)
validation_info = {"l1_dist": float("inf")}
discovered_modes = set()
is_on_policy = args.replay_buffer_size == 0

# Distributed computation specific.
my_batch_size = args.batch_size // world_size
sample_time = 0
to_train_samples_time = 0
loss_time = 0
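Note that the floor division in my_batch_size drops up to world_size - 1 trajectories per iteration whenever args.batch_size is not a multiple of world_size; a ceil-based split (a sketch, not what this commit does) would preserve the global batch at the cost of a slight oversample:

    from math import ceil  # already imported above

    my_batch_size = ceil(args.batch_size / world_size)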
@@ -293,19 +313,24 @@ def main(args): # noqa: C901
print("n_iterations = ", n_iterations)
print("my_batch_size = ", my_batch_size)
time_start = time.time()

for iteration in trange(n_iterations):
sample_start = time.time()
trajectories = gflownet.sample_trajectories(
env,
n_samples=my_batch_size,  # each rank draws its share of the global batch
save_logprobs=is_on_policy,
save_estimator_outputs=False,
)

sample_end = time.time()
sample_time += sample_end - sample_start

to_train_samples_start = time.time()
training_samples = gflownet.to_training_samples(trajectories)
to_train_samples_end = time.time()
to_train_samples_time += to_train_samples_end - to_train_samples_start

if replay_buffer is not None:
with torch.no_grad():
replay_buffer.add(training_samples)
@@ -533,6 +558,18 @@ def validate_hypergrid(
help="Name of the wandb project. If empty, don't use wandb",
)


parser.add_argument(
"--calculate_all_states",
action="store_true",
help="Enumerates all states.",
)
parser.add_argument(
"--calculate_partition",
action="store_true",
help="Calculates the true partition function.",
)

args = parser.parse_args()

print(main(args))
