# complex_word.py (forked from qiang2100/BERT-LS)
import labeler
import experiment
import collections
import statistics
import pandas as pd

# Load the pre-trained complex word identification (CWI) sequence labeller.
model_path = './gpu_attention.model'
model = labeler.SequenceLabeler.load(model_path)
config = model.config
predictions_cache = {}

# Invert the model's label2id mapping so predicted label ids can be mapped back to labels.
id2label = collections.OrderedDict()
for label in model.label2id:
    id2label[model.label2id[label]] = label


def get_complex_words(tokenised_string):
    """Return a per-token complexity probability for a tokenised sentence."""
    # Write the tokens to a tab-separated file in the format expected by the labeller
    # (one token per line, with a dummy 'N' label).
    dataframe = pd.DataFrame()
    dataframe['word'] = tokenised_string
    dataframe['binary'] = 'N'
    dataframe.to_csv('./complex_word.txt', sep='\t', index=False, header=False, quotechar=' ')

    sentences_test = experiment.read_input_files('./complex_word.txt')
    batches_of_sentence_ids = experiment.create_batches_of_sentence_ids(
        sentences_test, config["batch_equal_size"], config["max_batch_size"])

    for sentence_ids_in_batch in batches_of_sentence_ids:
        batch = [sentences_test[i] for i in sentence_ids_in_batch]
        cost, predicted_labels, predicted_probs = model.process_batch(
            batch, is_training=False, learningrate=0.0)
        if len(sentence_ids_in_batch) != len(predicted_labels):
            print('cw error')

    # Only one sentence is written to the file, so the single batch holds one sentence;
    # take its label probabilities and keep the probability of the 'complex' class.
    prob_labels = predicted_probs[0]
    probability_list = []
    for prob_pair in prob_labels:
        probability_list.append(prob_pair[1])
    return probability_list


def get_complexities(indexes, tokenized_sentence):
    """Average the complexity probabilities of the tokens at the given indexes."""
    probabilities = get_complex_words(tokenized_sentence)
    word_probs = [probabilities[each_index] for each_index in indexes]
    return float(sum(word_probs)) / len(word_probs)


def get_synonym_complexities(synonyms, tokenized, index):
    """Score each candidate substitution for the token at `index` in the sentence.

    Each entry in `synonyms` is a list of tokens (multi-word substitutions are
    allowed); the returned list holds one average complexity score per entry.
    """
    word_complexities = []
    for entry in synonyms:
        # Index list for multi-word replacements.
        indexes = []
        # Work on a copy of the original token list and remove the target word.
        tokenized_sentence = tokenized.copy()
        del tokenized_sentence[index]
        # If the synonym contains multiple words, the average complexity of its words is used.
        for i, word in enumerate(entry):
            # Insert the replacement words in place of the removed token.
            tokenized_sentence.insert(index + i, word)
            # Record the positions of the inserted words.
            indexes.append(index + i)
        prob = get_complexities(indexes, tokenized_sentence)
        word_complexities.append(prob)
    return word_complexities
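

# A minimal usage sketch, not part of the original file: the sentence and the
# candidate substitutions below are hypothetical, and running it assumes the
# gpu_attention.model file plus the labeler/experiment modules are available.
if __name__ == '__main__':
    tokens = ['The', 'committee', 'reached', 'a', 'unanimous', 'decision', '.']
    # Complexity probability for the word 'unanimous' (index 4).
    print(get_complexities([4], tokens))
    # Scores for two candidate substitutions, one single-word and one multi-word.
    candidates = [['agreed'], ['fully', 'agreed']]
    print(get_synonym_complexities(candidates, tokens, 4))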