-
Notifications
You must be signed in to change notification settings - Fork 1
/
nsi.py
40 lines (33 loc) · 1014 Bytes
/
nsi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python
import collections
import logging.config
import random
import types

import click
import numpy as np
import torch
import yaml

from src.data.make_dataset import split
from src.features.build_features import featuregen
from src.models.train_model import train
from src.utils.params import Params
# Fixed RNG seeds so every subcommand starts from a reproducible state.
# Changing any of these values changes downstream results.
_PY_SEED = 1037
_NUMPY_SEED = 99999
_TORCH_SEED = 1504
_TORCH_CUDA_SEED = 1610


def _configure_logging(path):
    """Load the YAML logging configuration at *path* and apply it via dictConfig."""
    with open(path, encoding='utf-8') as fp:
        log_cfg = yaml.safe_load(fp)
    logging.config.dictConfig(log_cfg)


def _seed_everything():
    """Seed Python's ``random``, NumPy's global RNG and torch with the fixed seeds."""
    random.seed(_PY_SEED)
    np.random.seed(_NUMPY_SEED)
    # torch.manual_seed seeds all devices; the CUDA call afterwards re-seeds
    # only the current CUDA device with its own value (a no-op without CUDA).
    torch.manual_seed(_TORCH_SEED)
    torch.cuda.manual_seed(_TORCH_CUDA_SEED)


@click.group()
@click.option('--config', type=click.Path(exists=True), default='settings.yml')
@click.option('--logging-config', type=click.Path(exists=True),
              default='logging_config.yml',
              help='YAML file passed to logging.config.dictConfig.')
@click.pass_context
def main(ctx, config, logging_config):
    """Command-line entry point for the project's pipeline subcommands.

    Reads the 'defaults' and 'hparams' settings from --config, applies the
    logging configuration from --logging-config, and seeds the random number
    generators before the chosen subcommand runs.
    """
    # Shared state for subcommands, read as ctx.obj.app_config / ctx.obj.hparams.
    # A namespace *instance* gives attribute-style (and still mutable) access
    # without creating a namedtuple class and overwriting its field descriptors.
    ctx.obj = types.SimpleNamespace(
        app_config=Params(config, 'defaults'),
        hparams=Params(config, 'hparams'),
    )
    _configure_logging(logging_config)
    _seed_everything()
# Attach each pipeline stage to the ``main`` group as a subcommand,
# in this order: featuregen, train, split.
for _command in (featuregen, train, split):
    main.add_command(_command)
del _command  # keep the loop variable out of the module namespace


if __name__ == '__main__':
    # Running the file directly hands control to the click group, which
    # parses the command line and dispatches to the requested subcommand.
    main()