
Commit

loss computation fixed
josephdviviano committed Aug 30, 2024
1 parent af4e52d commit cedc9b0
Showing 1 changed file with 19 additions and 2 deletions.
tutorials/examples/train_hypergrid.py: 19 additions & 2 deletions
@@ -40,6 +40,14 @@
 DEFAULT_SEED = 4444


+def average_gradients(model):
+    """All-Reduce gradients across all models."""
+    size = float(dist.get_world_size())
+    for param in model.parameters():
+        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
+        param.grad.data /= size
+
+
 def initalize_distributed_compute(dist_backend: str = "ccl"):
     """Initalizes distributed compute using either ccl or mpi backends."""
     global my_rank  # TODO: remove globals?
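For context on the new helper: `average_gradients` is the standard all-reduce pattern for keeping data-parallel replicas in sync. Every rank sums the gradients from all workers and divides by the world size, so each replica steps with the same mean gradient. A minimal single-process sketch of that arithmetic (toy values; the real code operates on `param.grad` tensors under `torch.distributed`):

import torch

# Two fake "ranks", each holding the gradient it computed on its local shard.
world_size = 2
rank_grads = [torch.tensor([1.0, 3.0]), torch.tensor([5.0, 7.0])]

# all_reduce with SUM leaves every rank holding the elementwise sum ...
summed = torch.stack(rank_grads).sum(dim=0)
# ... and dividing by the world size turns that into the mean gradient everywhere.
averaged = summed / world_size
print(averaged)  # tensor([3., 5.]) on every rank, so the replicas stay identical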
@@ -333,7 +341,7 @@ def main(args): # noqa: C901

 trajectories = gflownet.sample_trajectories(
     env,
-    n_samples=args.batch_size,
+    n_samples=per_node_batch_size,  # Split batch across all workers.
     save_logprobs=is_on_policy,
     save_estimator_outputs=False,
 )
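`per_node_batch_size` is defined outside this diff; from the inline comment it presumably splits the global `args.batch_size` evenly across the participating ranks. A hedged sketch of that split (the helper name and the even-division assumption are not from the commit):

import torch.distributed as dist

def local_batch_size(global_batch_size: int) -> int:
    """Hypothetical helper: how many samples this rank should draw, assuming an even split."""
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    assert global_batch_size % world_size == 0, "batch size must divide evenly across ranks"
    return global_batch_size // world_size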
@@ -361,11 +369,20 @@ def main(args): # noqa: C901

 # Time the loss computation
 loss_start = time.time()
-loss = gflownet.loss(env, training_objects)
+loss = gflownet.loss(
+    env,
+    training_objects,
+    reduction="sum" if args.distributed else "mean",
+)
 loss_end = time.time()
 loss_time = loss_end - loss_start
 total_loss_time += loss_time

+# Normalize the loss by the local batch size if distributed
+if args.distributed:
+    loss = loss / (per_node_batch_size)
+
+
 # Time backpropagation computation.
 loss_backward_start = time.time()
 loss.backward()
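Taken together, the `reduction="sum"` loss, the division by `per_node_batch_size`, and the gradient averaging in `average_gradients` reproduce the single-process mean-reduced objective: each rank forms a local mean, and averaging those local means across ranks equals the mean over the global batch whenever every rank draws the same number of samples. A toy sanity check of that equivalence (values are illustrative only):

import torch

world_size, per_node_batch_size = 2, 3
per_sample_losses = [
    torch.tensor([1.0, 2.0, 3.0]),  # rank 0's local batch
    torch.tensor([4.0, 5.0, 6.0]),  # rank 1's local batch
]

# Each rank: sum-reduced loss divided by its local batch size, as in this commit.
local_means = [losses.sum() / per_node_batch_size for losses in per_sample_losses]

# Averaging across ranks (what the gradient all-reduce does, since the gradient is linear) ...
distributed_mean = sum(local_means) / world_size

# ... matches the mean over the full global batch computed in one process.
global_mean = torch.cat(per_sample_losses).mean()
print(distributed_mean, global_mean)  # tensor(3.5000) tensor(3.5000)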
