forked from codekansas/keras-language-modeling
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_insurance_qa_embeddings.py
executable file
·67 lines (49 loc) · 1.96 KB
/
generate_insurance_qa_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env python
"""
Command-line script for generating embeddings
Useful if you want to generate larger embeddings for some models
"""
from __future__ import print_function
import os
import sys
import random
import pickle
import argparse
import logging
# Fix the seed of Python's global `random` module so any use of it in this
# process is repeatable.
# NOTE(review): nothing visible in this script calls `random` directly; the
# Word2Vec training below presumably draws from its own RNG - confirm whether
# this seed has any effect on the produced embeddings.
random.seed(a=42)
def load(path, name):
    """Unpickle and return the object stored at ``os.path.join(path, name)``.

    The file is opened in binary mode and closed as soon as it has been read.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; only point this at a dataset checkout you trust.
    """
    # `with` guarantees the handle is closed (the original leaked it).
    with open(os.path.join(path, name), 'rb') as handle:
        return pickle.load(handle)
def revert(vocab, indices):
    """Translate a sequence of token ids back into a list of words.

    Each id is looked up in ``vocab``; ids that are not present are replaced
    by the placeholder string ``'X'``.
    """
    placeholder = 'X'
    return [vocab.get(token_id, placeholder) for token_id in indices]
# The dataset directory comes from the INSURANCE_QA environment variable.
# Without it there is nothing to load, so print a hint and exit with status 1.
data_path = os.environ.get('INSURANCE_QA')
if data_path is None:
    print('INSURANCE_QA is not set. Set it to your clone of https://github.com/codekansas/insurance_qa_python')
    sys.exit(1)
# Command-line options. `--iter` and `--size` are forwarded to Word2Vec
# further down; `--size` also fixes the width of the saved embedding matrix
# and appears in the output file name.
parser = argparse.ArgumentParser(
    description='Generate embeddings for the InsuranceQA dataset',
)
parser.add_argument(
    '--iter',
    type=int,
    default=10,
    metavar='N',
    help='number of times to run',
)
parser.add_argument(
    '--size',
    type=int,
    default=100,
    metavar='D',
    help='dimensions in embedding',
)
args = parser.parse_args()
# Configure logging: a logger named after the invoked script, a timestamped
# format on the root handler, and INFO as the root threshold.
logger = logging.getLogger(os.path.basename(sys.argv[0]))
logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
# Set the level separately: basicConfig(level=...) is ignored if the root
# logger already has handlers, whereas this always applies.
logging.root.setLevel(level=logging.INFO)
# Pass the command line as a lazy %-argument rather than pre-formatting it.
logger.info('running %s', ' '.join(sys.argv))
# imports go down here because they are time-consuming
from gensim.models import Word2Vec
from keras_models import *
# Explicit imports: numpy was previously only reachable through
# `from keras_models import *`; gensim is needed for its version string.
import numpy as np
import gensim

# Build the training corpus: every answer plus every training question, each
# converted from token ids back to words (unknown ids become 'X').
vocab = load(data_path, 'vocabulary')
answers = load(data_path, 'answers')
sentences = [revert(vocab, txt) for txt in answers.values()]
sentences += [revert(vocab, q['question']) for q in load(data_path, 'train')]

# run model
# gensim 4.0 renamed the `size`/`iter` keywords to `vector_size`/`epochs`
# (the old names raise TypeError there); pick the spelling the installed
# version understands.
if int(gensim.__version__.split('.')[0]) >= 4:
    training_kwargs = {'vector_size': args.size, 'epochs': args.iter}
else:
    training_kwargs = {'size': args.size, 'iter': args.iter}
model = Word2Vec(sentences, min_count=5, window=5, sg=1, **training_kwargs)

# Word vectors live on `model.wv` (KeyedVectors) in gensim >= 1.0, where the
# old `model.syn0` / `model.vocab` attributes no longer exist in 4.x; very
# old releases supported `word in model` / `model[word]` directly.
word_vectors = getattr(model, 'wv', model)

# Row i holds the vector for vocabulary id i. Words that Word2Vec dropped
# (e.g. below min_count) keep an all-zero row.
# NOTE(review): the extra row (len(vocab) + 1) suggests ids are 1-based,
# with row 0 reserved - confirm against the dataset.
emb = np.zeros(shape=(len(vocab) + 1, args.size), dtype='float32')
for i, w in vocab.items():
    if w in word_vectors:
        emb[i, :] = word_vectors[w]

out_path = 'word2vec_%d_dim.embeddings' % args.size
# Save through an explicit file object so np.save keeps the exact name
# (given a path string it would append '.npy'), and close it deterministically.
with open(out_path, 'wb') as out_file:
    np.save(out_file, emb)
logger.info('saved to "%s"', out_path)