-
Notifications
You must be signed in to change notification settings - Fork 112
/
train.py
32 lines (22 loc) · 803 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# -*- coding: utf-8 -*-
# @Time : 2019/8/23 22:00
# @Author : zhoujun
from __future__ import print_function
import os
from utils import load_json
config = load_json('config.json')
# Expose only the GPUs listed under trainer.gpus to this process. This runs
# before the `models` / `data_loader` / `trainer` imports on purpose.
# NOTE(review): presumably those modules import torch, and CUDA reads
# CUDA_VISIBLE_DEVICES when it initializes — keep this above them.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
    map(str, config['trainer']['gpus'])
)
from models import get_model, get_loss
from data_loader import get_dataloader
from trainer import Trainer
def main(config):
    """Build the training components from ``config`` and run training.

    Reads ``config['data_loader']['type']`` and ``config['data_loader']['args']``
    to build the training loader. It then builds the loss and the model from the
    full config, wires them into a ``Trainer`` and calls ``train()`` on it.

    Args:
        config: configuration dict (at the script entry point this is the
            dict loaded from ``config.json``).
    """
    loader_cfg = config['data_loader']
    data = get_dataloader(loader_cfg['type'], loader_cfg['args'])

    loss_fn = get_loss(config)
    # NOTE(review): the criterion is moved to the GPU here unconditionally —
    # presumably a torch module; this fails if no CUDA device is visible.
    loss_fn = loss_fn.cuda()

    net = get_model(config)
    Trainer(config=config, model=net, criterion=loss_fn, train_loader=data).train()
if __name__ == '__main__':
    # Script entry point: train with the module-level config loaded above.
    main(config=config)