Skip to content

Commit

Permalink
Added support for PyTorch Profiler
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-misra committed Aug 21, 2024
1 parent d1b9e29 commit b5c041c
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tutorials/examples/train_hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from gfn.utils.common import set_seed
from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
from gfn.utils.training import validate
from torch.profiler import profile, ProfilerActivity

DEFAULT_SEED = 4444

Expand Down Expand Up @@ -237,7 +238,20 @@ def main(args): # noqa: C901
discovered_modes = set()
is_on_policy = args.replay_buffer_size == 0

if (args.profile):
keep_active = args.trajectories_to_profile//args.batch_size
prof = profile(
schedule=torch.profiler.schedule(wait=1, warmup=1, active=keep_active, repeat=1),
activities=[ProfilerActivity.CPU],
record_shapes=True,
with_stack=True
)
prof.start()
for iteration in trange(n_iterations):
if (args.profile):
prof.step()
if iteration >= 1 + 1 + keep_active:
break
trajectories = gflownet.sample_trajectories(
env,
n_samples=args.batch_size,
Expand Down Expand Up @@ -279,6 +293,10 @@ def main(args): # noqa: C901
to_log.update(validation_info)
tqdm.write(f"{iteration}: {to_log}")

if (args.profile):
prof.stop()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")
try:
return validation_info["l1_dist"]
except KeyError:
Expand Down Expand Up @@ -456,6 +474,18 @@ def validate_hypergrid(
action="store_true",
help="Calculates the true partition function.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Profiles the execution using PyTorch Profiler.",
)
parser.add_argument(
"--trajectories_to_profile",
type=int,
default=2048,
help="Number of trajectories to profile using the Pytorch Profiler." +
" Preferably, a multiple of batch size.",
)

args = parser.parse_args()

Expand Down

0 comments on commit b5c041c

Please sign in to comment.