-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain.py
44 lines (36 loc) · 1.19 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
from args import parse_args
from dkt.dataloader import Preprocess
from dkt import trainer
from dkt.trainer import update_train_data
import torch
from dkt.utils import setSeeds
import wandb
import json
import argparse
def main(args):
    """Train the model selected by ``args.model`` on the DKT data.

    Steps, in order:
      1. If ``args.use_wandb`` is set, log in to Weights & Biases and start a
         run in the ``'dkt'`` project with every argument recorded as config.
      2. Seed RNGs with the fixed value 42 via ``setSeeds``.
      3. Choose ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``,
         and store it on ``args.device`` (this mutates ``args``).
      4. Build a ``Preprocess`` and load train/valid data. If
         ``args.use_pseudo`` is set, also load test data.
      5. Dispatch to ``trainer.tabnet_run``, ``trainer.lgbm_run`` or the
         generic ``trainer.run`` depending on ``args.model``.

    Args:
        args: Parsed option namespace. Reads ``use_wandb``,
            ``train_file_name``, ``valid_file_name``, ``use_pseudo``,
            ``test_file_name`` (only when ``use_pseudo``) and ``model``.
            Writes ``device``.

    Returns:
        None.
    """
    if args.use_wandb:
        wandb.login()
        wandb.init(project='dkt', config=vars(args))

    setSeeds(42)
    args.device = "cuda" if torch.cuda.is_available() else "cpu"

    preprocess = Preprocess(args)
    # NOTE(review): the validation file is handed to load_train_data as well
    # as to load_valid_data; the reason is not visible here — confirm in
    # dkt.dataloader.Preprocess.
    preprocess.load_train_data(args.train_file_name, args.valid_file_name)
    preprocess.load_valid_data(args.valid_file_name)
    train_data, valid_data = preprocess.get_train_data(), preprocess.get_valid_data()

    # Test data is only loaded when args.use_pseudo is set; otherwise the
    # runners receive None (presumably pseudo-labelling — confirm in trainer).
    test_data = None
    if args.use_pseudo:
        preprocess.load_test_data(args.test_file_name)
        test_data = preprocess.get_test_data()

    if args.model == 'lgbm':
        # The LightGBM runner only receives args; the data loaded above is
        # not passed to it.
        trainer.lgbm_run(args)
        return

    run_fn = trainer.tabnet_run if args.model == 'tabnet' else trainer.run
    run_fn(args, train_data, valid_data, test_data)
if __name__ == "__main__":
    # Script entry point: parse the training options, make sure the
    # directory named by args.model_dir exists (no error if it already
    # does), then hand control to main().
    args = parse_args(mode='train')
    os.makedirs(name=args.model_dir, exist_ok=True)
    main(args=args)