Skip to content

Commit

Permalink
Use trainer.predict to run inference
Browse files Browse the repository at this point in the history
  • Loading branch information
yanbing-j committed Nov 30, 2022
1 parent a842429 commit d6d0fd0
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions examples/lsc/mag240m/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from typing import Dict, Tuple
from torch_geometric.data import Batch
from torch_geometric.data import LightningNodeData
import pathlib
from torch.profiler import ProfilerActivity, profile

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Expand Down Expand Up @@ -69,7 +71,6 @@ def __init__(self, model: str, in_channels: int, out_channels: int,
self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, hidden_channels)
self.lin = Linear(hidden_channels, out_channels)
# self.relu = ReLU(inplace=True)

def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
x = x.to(torch.float)
Expand Down Expand Up @@ -124,11 +125,25 @@ def test_step(self, batch: Batch, batch_idx: int):
self.log('test_acc', self.test_acc, on_step=False, on_epoch=True,
prog_bar=True, sync_dist=True)

def predict_step(self, batch: Batch, batch_idx: int):
y_hat, y = self.common_step(batch)
return y_hat

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=25, gamma=0.25)
return [optimizer], [scheduler]

def trace_handler(p):
if torch.cuda.is_available():
profile_sort = 'self_cuda_time_total'
else:
profile_sort = 'self_cpu_time_total'
output = p.key_averages().table(sort_by=profile_sort)
print(output)
profile_dir = str(pathlib.Path.cwd()) + '/'
timeline_file = profile_dir + 'timeline' + '.json'
p.export_chrome_trace(timeline_file)

if __name__ == '__main__':
parser = argparse.ArgumentParser()
Expand All @@ -142,14 +157,20 @@ def configure_optimizers(self):
parser.add_argument('--in-memory', action='store_true')
parser.add_argument('--device', type=str, default='0')
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--profile', action='store_true')
args = parser.parse_args()
args.sizes = [int(i) for i in args.sizes.split('-')]
print(args)

seed_everything(42)
dataset = MAG240MDataset(ROOT)
data = dataset.to_pyg_hetero_data()
datamodule = MAG240M(data, loader='neighbor', num_neighbors=args.sizes, batch_size=args.batch_size, num_workers=2)
datamodule = MAG240M(data, ('paper', data['paper'].train_mask),
('paper', data['paper'].val_mask),
('paper', data['paper'].test_mask),
('paper', data['paper'].test_mask),
loader='neighbor', num_neighbors=args.sizes,
batch_size=args.batch_size, num_workers=2)
print(datamodule)

if not args.evaluate:
Expand All @@ -160,7 +181,9 @@ def configure_optimizers(self):
checkpoint_callback = ModelCheckpoint(monitor='val_acc', mode = 'max', save_top_k=1)
trainer = Trainer(accelerator="cpu", max_epochs=args.epochs,
callbacks=[checkpoint_callback],
default_root_dir=f'logs/{args.model}')
default_root_dir=f'logs/{args.model}',
limit_train_batches=10, limit_test_batches=10,
limit_val_batches=10, limit_predict_batches=10)
trainer.fit(model, datamodule=datamodule)

if args.evaluate:
Expand All @@ -177,7 +200,12 @@ def configure_optimizers(self):
datamodule.batch_size = 16
datamodule.sizes = [160] * len(args.sizes) # (Almost) no sampling...

trainer.test(model=model, datamodule=datamodule)
trainer.predict(model=model, datamodule=datamodule)
if args.profile:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
on_trace_ready=trace_handler) as p:
trainer.predict(model=model, datamodule=datamodule)
p.step()

# evaluator = MAG240MEvaluator()
# loader = datamodule.hidden_test_dataloader()
Expand Down

0 comments on commit d6d0fd0

Please sign in to comment.