From d6d0fd0754d7a8024b0dce5d9185eb02aff11cbe Mon Sep 17 00:00:00 2001 From: yanbing-j Date: Tue, 29 Nov 2022 21:19:59 +0800 Subject: [PATCH] Use trainer.predict to run inference --- examples/lsc/mag240m/gnn.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/examples/lsc/mag240m/gnn.py b/examples/lsc/mag240m/gnn.py index 7f55532f..816af124 100644 --- a/examples/lsc/mag240m/gnn.py +++ b/examples/lsc/mag240m/gnn.py @@ -31,6 +31,8 @@ from typing import Dict, Tuple from torch_geometric.data import Batch from torch_geometric.data import LightningNodeData +import pathlib +from torch.profiler import ProfilerActivity, profile device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -69,7 +71,6 @@ def __init__(self, model: str, in_channels: int, out_channels: int, self.conv1 = SAGEConv(in_channels, hidden_channels) self.conv2 = SAGEConv(hidden_channels, hidden_channels) self.lin = Linear(hidden_channels, out_channels) - # self.relu = ReLU(inplace=True) def forward(self, x: Tensor, edge_index: Adj) -> Tensor: x = x.to(torch.float) @@ -124,11 +125,25 @@ def test_step(self, batch: Batch, batch_idx: int): self.log('test_acc', self.test_acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + def predict_step(self, batch: Batch, batch_idx: int): + y_hat, y = self.common_step(batch) + return y_hat + def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=0.001) scheduler = StepLR(optimizer, step_size=25, gamma=0.25) return [optimizer], [scheduler] +def trace_handler(p): + if torch.cuda.is_available(): + profile_sort = 'self_cuda_time_total' + else: + profile_sort = 'self_cpu_time_total' + output = p.key_averages().table(sort_by=profile_sort) + print(output) + profile_dir = str(pathlib.Path.cwd()) + '/' + timeline_file = profile_dir + 'timeline' + '.json' + p.export_chrome_trace(timeline_file) if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -142,6 +157,7 @@ def configure_optimizers(self): parser.add_argument('--in-memory', action='store_true') parser.add_argument('--device', type=str, default='0') parser.add_argument('--evaluate', action='store_true') + parser.add_argument('--profile', action='store_true') args = parser.parse_args() args.sizes = [int(i) for i in args.sizes.split('-')] print(args) @@ -149,7 +165,12 @@ def configure_optimizers(self): seed_everything(42) dataset = MAG240MDataset(ROOT) data = dataset.to_pyg_hetero_data() - datamodule = MAG240M(data, loader='neighbor', num_neighbors=args.sizes, batch_size=args.batch_size, num_workers=2) + datamodule = MAG240M(data, ('paper', data['paper'].train_mask), + ('paper', data['paper'].val_mask), + ('paper', data['paper'].test_mask), + ('paper', data['paper'].test_mask), + loader='neighbor', num_neighbors=args.sizes, + batch_size=args.batch_size, num_workers=2) print(datamodule) if not args.evaluate: @@ -160,7 +181,9 @@ def configure_optimizers(self): checkpoint_callback = ModelCheckpoint(monitor='val_acc', mode = 'max', save_top_k=1) trainer = Trainer(accelerator="cpu", max_epochs=args.epochs, callbacks=[checkpoint_callback], - default_root_dir=f'logs/{args.model}') + default_root_dir=f'logs/{args.model}', + limit_train_batches=10, limit_test_batches=10, + limit_val_batches=10, limit_predict_batches=10) trainer.fit(model, datamodule=datamodule) if args.evaluate: @@ -177,7 +200,12 @@ def configure_optimizers(self): datamodule.batch_size = 16 datamodule.sizes = [160] * len(args.sizes) # (Almost) no sampling... - trainer.test(model=model, datamodule=datamodule) + trainer.predict(model=model, datamodule=datamodule) + if args.profile: + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + on_trace_ready=trace_handler) as p: + trainer.predict(model=model, datamodule=datamodule) + p.step() # evaluator = MAG240MEvaluator() # loader = datamodule.hidden_test_dataloader()