-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
54 lines (43 loc) · 1.83 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import dgl
import torch
import torch_geometric
import numpy as np
import argparse
from GDSS.parsers.config import get_config
from GDSS.trainer import Trainer
from GDSS.utils.data_loader import dataloader
from data import AnomalyDataset
from anomaly_scores import save_final_scores
def main(args):
    """Run the full pipeline: load the dataset, train GDSS, then score.

    Args:
        args: Parsed command-line namespace. This function reads its
            ``config``, ``seed``, ``exp_name``, ``trajectory_sample`` and
            ``num_sample`` attributes.
    """
    cfg = get_config(args.config, args.seed)
    run_name, per_trajectory, per_node = (
        args.exp_name,
        args.trajectory_sample,
        args.num_sample,
    )

    # --- Dataset -----------------------------------------------------------
    name = cfg.data.data
    graphs = AnomalyDataset(name, radius=1, undirected=cfg.model.sym)
    print(f'Dataset: {name}')
    print(f'Number of nodes: {len(graphs)}')

    # Overwrite the configured sizes with values taken from the dataset.
    # NOTE(review): the log label says max_node_num is a 95% quantile;
    # presumably AnomalyDataset computes it that way -- confirm there.
    cfg.data.max_node_num = graphs.max_node_num
    print(f'Max size subgraphs (95% quantile): {cfg.data.max_node_num}')
    cfg.data.max_feat_num = graphs.feat_dim
    print(f'Feature dimension: {cfg.data.max_feat_num}')

    # --- Train GDSS --------------------------------------------------------
    # The trainer is constructed first; the loader is attached afterwards.
    trainer = Trainer(cfg)
    trainer.train_loader = dataloader(cfg, graphs)
    # Keep whatever train() returns on the config (presumably a checkpoint
    # reference that save_final_scores reads -- TODO confirm).
    cfg.ckpt = trainer.train(run_name)

    # --- Inference ---------------------------------------------------------
    save_final_scores(cfg, graphs, run_name, per_trajectory, per_node)
if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to main().
    cli = argparse.ArgumentParser()
    add = cli.add_argument
    # Mandatory inputs.
    add('--config', type=str, required=True, help='config path')
    add('--exp_name', type=str, required=True, help='experiment name')
    # Optional settings (argparse "--" options are optional by default).
    add('--trajectory_sample', type=int, default=1,
        help='number of samples per trajectory')
    add('--num_sample', type=int, default=1,
        help='number of samples per node')
    add('--seed', type=int, default=42, help='rng seed value')
    main(cli.parse_args())