-
Notifications
You must be signed in to change notification settings - Fork 77
/
main_hyperparams.py
120 lines (98 loc) · 4.35 KB
/
main_hyperparams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# @Author : bamtercelboo
# @Datetime : 2018/1/30 19:50
# @File : main_hyperparams.py.py
# @Last Modify Time : 2018/1/30 19:50
# @Contact : bamtercelboo@{gmail.com, 163.com}
"""
FILE : main_hyperparams.py.py
FUNCTION : main
"""
# stdlib
import argparse
import datetime
import os
import random
import shutil
import sys

# third-party
import torch

# project-local
import Config.config as configurable
from DataUtils.Alphabet import *
from DataUtils.Batch_Iterator import *
from Dataloader import DataLoader_NER
from DataUtils.Load_Pretrained_Embed import *
from DataUtils.Common import seed_num, paddingkey
from models.BiLSTM import *
import train
# solve default encoding problem
from imp import reload
defaultencoding = 'utf-8'
if sys.getdefaultencoding() != defaultencoding:
reload(sys)
sys.setdefaultencoding(defaultencoding)
# random seed
torch.manual_seed(seed_num)
random.seed(seed_num)
# load data / create alphabet / create iterator
# load data / create alphabet / create iterator
def load_Data(config):
    """Load the train/dev/test corpora, build the vocabulary alphabet and
    the batch iterators.

    :param config: Configurable object holding file paths, batch sizes,
                   ``shuffle``, ``min_freq`` and ``embed_finetune`` flags.
    :return: tuple ``(train_iter, dev_iter, test_iter, create_alphabet)``.
    """
    print("Loading Data......")
    # read the three corpus files in one call
    data_loader = DataLoader_NER.DataLoader()
    train_data, dev_data, test_data = data_loader.dataLoader(
        path=[config.train_file, config.dev_file, config.test_file],
        shuffle=config.shuffle)
    print("train sentence {}, dev sentence {}, test sentence {}.".format(
        len(train_data), len(dev_data), len(test_data)))
    # create the alphabet: when the embedding is fine-tuned the vocabulary is
    # built from the training split only; otherwise all three splits are used
    # (the original pair of `is True`/`is False` tests silently skipped both
    # branches for truthy non-bool values — a single if/else cannot).
    create_alphabet = CreateAlphabet(min_freq=config.min_freq)
    if config.embed_finetune:
        create_alphabet.build_vocab(train_data=train_data)
    else:
        create_alphabet.build_vocab(train_data=train_data, dev_data=dev_data,
                                    test_data=test_data)
    # create one batch iterator per split
    create_iter = Iterators()
    train_iter, dev_iter, test_iter = create_iter.createIterator(
        batch_size=[config.batch_size, config.dev_batch_size, config.test_batch_size],
        data=[train_data, dev_data, test_data], operator=create_alphabet,
        config=config)
    return train_iter, dev_iter, test_iter, create_alphabet
def main():
    """Entry point: create the save directory, load the data, build the
    model and start training.

    Relies on the module-level ``config`` object created under
    ``if __name__ == "__main__"``; mutates it with derived fields
    (``embed_num``, ``class_num``, padding ids, pretrained weights).

    :raises RuntimeError: if no model is selected in the configuration.
    """
    # save file: timestamped sub-directory for checkpoints/results
    config.mulu = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    config.save_dir = os.path.join(config.save_direction, config.mulu)
    if not os.path.isdir(config.save_dir):
        os.makedirs(config.save_dir)
    # get iterators and the alphabet built from the corpora
    train_iter, dev_iter, test_iter, create_alphabet = load_Data(config)
    # derived hyper-parameters the model constructor reads from config
    config.embed_num = create_alphabet.word_alphabet.vocab_size
    config.class_num = create_alphabet.label_alphabet.vocab_size
    config.paddingId = create_alphabet.word_paddingId
    config.label_paddingId = create_alphabet.label_paddingId
    config.create_alphabet = create_alphabet
    print("embed_num : {}, class_num : {}".format(config.embed_num, config.class_num))
    print("PaddingID {}".format(config.paddingId))
    if config.pretrained_embed:
        print("Using Pre_Trained Embedding.")
        # load embeddings for the known vocabulary, zeros for OOV/pad
        pretrain_embed = load_pretrained_emb_zeros(
            path=config.pretrained_embed_file,
            text_field_words_dict=create_alphabet.word_alphabet.id2words,
            pad=paddingkey)
        config.pretrained_weight = pretrain_embed
    model = None
    if config.model_BiLstm is True:
        print("loading model.....")
        model = BiLSTM(config)
        print(model)
        if config.use_cuda is True:
            model = model.cuda()
    if model is None:
        # fail fast: the original silently passed model=None to the trainer,
        # which only crashed later with an obscure attribute error
        raise RuntimeError("No model selected in config (expected model_BiLstm=True).")
    print("Training Start......")
    train.train(train_iter=train_iter, dev_iter=dev_iter, test_iter=test_iter,
                model=model, config=config)
if __name__ == "__main__":
print("Process ID {}, Process Parent ID {}".format(os.getpid(), os.getppid()))
parser = argparse.ArgumentParser(description="Chinese NER & POS")
parser.add_argument('--config_file', default="./Config/config.cfg")
args = parser.parse_args()
config = configurable.Configurable(config_file=args.config_file)
if config.use_cuda is True:
print("Using GPU To Train......")
# torch.backends.cudnn.enabled = True
# torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed(seed_num)
torch.cuda.manual_seed_all(seed_num)
print("torch.cuda.initial_seed", torch.cuda.initial_seed())
main()