-
Notifications
You must be signed in to change notification settings - Fork 24
/
inference.py
80 lines (65 loc) · 2.74 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import pandas as pd
from train_util import AddEgoIds, extract_param, add_arange_ids, get_loaders, evaluate_homo, evaluate_hetero
from training import get_model
from torch_geometric.nn import to_hetero, summary
import wandb
import logging
import os
import sys
import time
# Wall-clock timestamp (seconds since the epoch) captured once, when this
# module is first executed/imported. Not referenced by the code visible here.
script_start = time.time()
def infer_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config):
    """Restore a saved model checkpoint and evaluate it on the test split.

    The function initialises a wandb run, builds the data loaders, constructs
    the model from a sample training batch, loads the weights stored at
    ``<data_config["paths"]["model_to_load"]>/checkpoint_<args.unique_name>.tar``
    and finally runs ``evaluate_homo`` or ``evaluate_hetero`` on the test
    loader, depending on ``args.reverse_mp``.

    Args:
        tr_data: Training data object; passed to ``add_arange_ids`` and
            ``get_loaders``.
        val_data: Validation data object; passed to ``add_arange_ids`` and
            ``get_loaders``.
        te_data: Test data object; additionally supplies ``metadata()`` for
            ``to_hetero`` when ``args.reverse_mp`` is set and is handed to the
            evaluation routine.
        tr_inds: Training indices forwarded to ``get_loaders``.
        val_inds: Validation indices forwarded to ``get_loaders``.
        te_inds: Test indices forwarded to ``get_loaders`` and to the
            evaluation routine.
        args: Run options. Attributes read here: ``testing``, ``n_epochs``,
            ``batch_size``, ``model``, ``data``, ``num_neighs``, ``ego``,
            ``reverse_mp``, ``avg_tps``, ``finetune`` and ``unique_name``.
            ``args.unique_name`` is overwritten with ``""`` unless
            ``args.avg_tps`` or ``args.finetune`` is truthy.
        data_config: Mapping that provides
            ``data_config["paths"]["model_to_load"]``, the directory holding
            the checkpoint file.

    Returns:
        None. The test metrics are computed but not returned.
    """
    # Prefer the first CUDA device, fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # The wandb run config doubles as the model config handed to get_model().
    run_config = {
        "epochs": args.n_epochs,
        "batch_size": args.batch_size,
        "model": args.model,
        "data": args.data,
        "num_neighbors": args.num_neighs,
        "lr": extract_param("lr", args),
        "n_hidden": extract_param("n_hidden", args),
        "n_gnn_layers": extract_param("n_gnn_layers", args),
        "loss": "ce",
        "w_ce1": extract_param("w_ce1", args),
        "w_ce2": extract_param("w_ce2", args),
        "dropout": extract_param("dropout", args),
        "final_dropout": extract_param("final_dropout", args),
        "n_heads": extract_param("n_heads", args) if args.model == 'gat' else None,
    }
    wandb.init(
        mode="disabled" if args.testing else "online",
        project="your_proj_name",
        config=run_config,
    )
    config = wandb.config

    # Ego-id transform is only applied when requested.
    transform = AddEgoIds() if args.ego else None

    # Add the unique ids used later to find the seed edges.
    add_arange_ids([tr_data, val_data, te_data])

    # The validation loader is built by get_loaders() but not needed here.
    tr_loader, _, te_loader = get_loaders(
        tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args
    )

    # get_model() needs a sample batch to build the network.
    sample_batch = next(iter(tr_loader))
    model = get_model(sample_batch, config, args)

    if args.reverse_mp:
        model = to_hetero(model, te_data.metadata(), aggr='mean')

    if not (args.avg_tps or args.finetune):
        # NOTE(review): the previous code derived this from an empty string
        # ('-'.join("".split('-')[3:])), which always evaluates to "". This
        # looks like a placeholder for a real run name -- confirm what the
        # checkpoint suffix should be.
        args.unique_name = ""

    logging.info("=> loading model checkpoint")
    checkpoint_path = f'{data_config["paths"]["model_to_load"]}/checkpoint_{args.unique_name}.tar'
    # map_location lets a checkpoint that was saved on a GPU be restored on a
    # CPU-only machine (and vice versa) instead of failing at deserialisation.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    logging.info("=> loaded checkpoint (epoch %s)", start_epoch)

    # Heterogeneous evaluation is used for the to_hetero()-converted model.
    evaluate = evaluate_hetero if args.reverse_mp else evaluate_homo
    # The unpacking is kept so an unexpected return shape still fails loudly;
    # the metrics themselves are not used further in this function.
    te_f1, te_prec, te_rec = evaluate(
        te_loader, te_inds, model, te_data, device, args, precrec=True
    )

    wandb.finish()