-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathuse.py
41 lines (36 loc) · 1.7 KB
/
use.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tensorflow_hub as hub
import os
import numpy as np
class USE(object):
    """Sentence-pair similarity scorer built on the Universal Sentence Encoder.

    Loads ``universal-sentence-encoder-large/5`` through ``tensorflow_hub``,
    builds a TF1-style graph that compares two batches of sentences, and owns
    the ``tf.Session`` used to evaluate it.

    The session holds native resources, so call :meth:`close` when done, or
    use the instance as a context manager::

        with USE('./cache') as use:
            scores = use.semantic_sim(sents1, sents2)
    """

    def __init__(self, cache_path='./cache'):
        """Load the encoder, build the similarity graph and start a session.

        Args:
            cache_path: Directory exported as ``TFHUB_CACHE_DIR`` before the
                module is loaded. Note that this mutates ``os.environ`` for
                the whole process.
        """
        super(USE, self).__init__()
        os.environ['TFHUB_CACHE_DIR'] = cache_path
        module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
        self.embed = hub.load(module_url)
        config = tf.ConfigProto()
        # Let TensorFlow grow its GPU allocation as needed.
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        try:
            self.build_graph()
            self.sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        except Exception:
            # Do not leak the session when graph construction or variable
            # initialisation fails; the caller never receives the instance.
            self.sess.close()
            raise

    def __enter__(self):
        """Return the instance itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the session on exit; exceptions are never suppressed."""
        self.close()
        return False

    def close(self):
        """Release the underlying session by calling its ``close()``."""
        self.sess.close()

    def build_graph(self):
        """Create the placeholders and the similarity ops.

        Sets ``sts_input1`` / ``sts_input2`` (string placeholders),
        ``cosine_similarities`` (row-wise cosine of the two embedding
        batches) and ``sim_scores`` (``1 - arccos`` of the clipped cosine).
        """
        # ``(None,)`` declares a 1-D batch of strings. The previous
        # ``shape=(None)`` is just ``None``, i.e. a placeholder of unknown rank.
        self.sts_input1 = tf.placeholder(tf.string, shape=(None,))
        self.sts_input2 = tf.placeholder(tf.string, shape=(None,))
        sts_encode1 = tf.nn.l2_normalize(self.embed(self.sts_input1), axis=1)
        sts_encode2 = tf.nn.l2_normalize(self.embed(self.sts_input2), axis=1)
        # Rows are unit-normalised, so the row-wise dot product is the cosine.
        self.cosine_similarities = tf.reduce_sum(tf.multiply(sts_encode1, sts_encode2), axis=1)
        # Rounding can push the cosine slightly outside [-1, 1], where acos
        # is undefined, so clip before taking the angle.
        clip_cosine_similarities = tf.clip_by_value(self.cosine_similarities, -1.0, 1.0)
        self.sim_scores = 1.0 - tf.acos(clip_cosine_similarities)

    def semantic_sim(self, sents1, sents2):
        """Score each sentence in ``sents1`` against its partner in ``sents2``.

        Args:
            sents1: Batch of sentences fed to the first placeholder.
            sents2: Batch of sentences fed to the second placeholder.

        Returns:
            A numpy array of ``1 - arccos(cos_sim)`` values, one per pair.
            The angle is in radians and is not divided by pi, so scores lie
            in ``[1 - pi, 1]``, with ``1`` meaning a cosine of exactly one.
        """
        # Fetch the tensor directly rather than wrapping it in a one-element
        # list and unwrapping the result again.
        scores = self.sess.run(
            self.sim_scores,
            feed_dict={
                self.sts_input1: sents1,
                self.sts_input2: sents2,
            })
        return np.asarray(scores)
if __name__ == "__main__":
    # Smoke test: score one sentence against a sentiment-swapped variant and
    # against an exact copy of itself, then print both scores.
    reference = "Food here was great, but service was bad"
    swapped = "Food here was bad, but service was nice"
    encoder = USE('./cache')
    pair_scores = encoder.semantic_sim(
        [reference, reference],
        [swapped, reference],
    )
    print(pair_scores)