-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
59 lines (45 loc) · 1.89 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import sys

import fasttext
import ipdb
if __name__ == "__main__":
"""Entry point
Run two pipeline
train:
saves GB sizes model.bin binaries containing models, its weights and word vectors
(no need to save .vec as contained in .bin)
predict:
load model and use to make inferences
"""
arg = sys.argv[1]
if arg == "train":
print("\nSKIPGRAM ======\n")
# train model
skipgram_model = fasttext.train_unsupervised("data/ft.train", model="skipgram")
# save model (.bin)
skipgram_model.save_model("model/skipgram_model.bin")
# list of words in dictionary
print(skipgram_model.words) # list of words in dictionary
print("""King's embedding:\n""") # get the vector of the word 'king'
print(skipgram_model["life"]) # get the vector of the word 'king'
print("\nCBOW ======\n")
# train model
cbow_model = fasttext.train_unsupervised("data/ft.train", model="cbow")
# list of words in dictionary
print(cbow_model.words) # list of words in dictionary
print("""\nKing's embedding:\n""") # get the vector of the word 'king'
print(cbow_model["life"]) # get the vector of the word 'king'
# save model (.bin)
cbow_model.save_model("model/cbow_model.bin")
elif arg == "predict":
print("\SKIPGRAM ======\n")
skipgram_model = fasttext.load_model("model/skipgram_model.bin")
life_vec = skipgram_model.get_word_vector("life")
print(life_vec)
nn = skipgram_model.get_nearest_neighbors("life")
print(nn)
print("\nCBOW ======\n")
cbow_model = fasttext.load_model("model/cbow_model.bin")
life_vec = cbow_model.get_word_vector("life")
print("\n", life_vec, "\n")
nn = cbow_model.get_nearest_neighbors("life")
print("\n", nn, "\n")