From b5c041cc909fd9f171ea8e61c7aa5039734452dc Mon Sep 17 00:00:00 2001
From: smisra
Date: Wed, 21 Aug 2024 04:57:33 -0700
Subject: [PATCH] Added support for PyTorch Profiler

---
 tutorials/examples/train_hypergrid.py | 30 +++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index b585a9c7..2e66f93f 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -30,6 +30,7 @@
 from gfn.utils.common import set_seed
 from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
 from gfn.utils.training import validate
+from torch.profiler import profile, ProfilerActivity
 
 DEFAULT_SEED = 4444
 
@@ -237,7 +238,20 @@ def main(args):  # noqa: C901
 
     discovered_modes = set()
     is_on_policy = args.replay_buffer_size == 0
+    if args.profile:
+        keep_active = args.trajectories_to_profile // args.batch_size  # number of profiled steps
+        prof = profile(
+            schedule=torch.profiler.schedule(wait=1, warmup=1, active=keep_active, repeat=1),
+            activities=[ProfilerActivity.CPU],
+            record_shapes=True,
+            with_stack=True,
+        )
+        prof.start()
     for iteration in trange(n_iterations):
+        if args.profile:
+            prof.step()
+            if iteration >= 1 + 1 + keep_active:  # wait + warmup + active steps consumed
+                break
         trajectories = gflownet.sample_trajectories(
             env,
             n_samples=args.batch_size,
@@ -279,6 +293,10 @@ def main(args):  # noqa: C901
             to_log.update(validation_info)
             tqdm.write(f"{iteration}: {to_log}")
 
+    if args.profile:
+        prof.stop()
+        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+        prof.export_chrome_trace("trace.json")
     try:
         return validation_info["l1_dist"]
     except KeyError:
@@ -456,6 +474,18 @@ def validate_hypergrid(
         action="store_true",
         help="Calculates the true partition function.",
     )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Profiles the execution using the PyTorch Profiler.",
+    )
+    parser.add_argument(
+        "--trajectories_to_profile",
+        type=int,
+        default=2048,
+        help="Number of trajectories to profile using the PyTorch Profiler."
+        + " Preferably a multiple of the batch size.",
+    )
 
     args = parser.parse_args()
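
Note: the following is a minimal, self-contained sketch of the torch.profiler schedule/step pattern that this patch wires into the training loop. The trajectory count, batch size, and the matmul "work" are hypothetical stand-ins for the gflownet training step; only the profiler calls mirror the patch.

import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Hypothetical stand-ins; in the patch these come from
# --trajectories_to_profile and --batch_size.
trajectories_to_profile = 2048
batch_size = 256
keep_active = trajectories_to_profile // batch_size  # number of profiled steps

prof = profile(
    # Skip 1 step, warm up for 1 step, then record `keep_active` steps, once.
    schedule=schedule(wait=1, warmup=1, active=keep_active, repeat=1),
    activities=[ProfilerActivity.CPU],
    record_shapes=True,
    with_stack=True,
)
prof.start()

for iteration in range(10_000):
    prof.step()  # advance the profiler schedule once per training iteration
    if iteration >= 1 + 1 + keep_active:  # wait + warmup + active steps consumed
        break
    # Stand-in for one training step (sampling trajectories, loss, backward).
    x = torch.randn(batch_size, 64, requires_grad=True)
    (x @ x.T).sum().backward()

prof.stop()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")  # viewable in chrome://tracing or Perfetto

With the patch applied, the same pattern is driven by the new flags, roughly: python tutorials/examples/train_hypergrid.py --profile --trajectories_to_profile 2048, alongside the script's existing arguments.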