# udc_predict.py (forked from sohomghosh/chatbot-retrieval)
import os
import time
import itertools
import sys
import numpy as np
import tensorflow as tf
import udc_model
import udc_hparams
import udc_metrics
import udc_inputs
from models.dual_encoder import dual_encoder_model
from models.helpers import load_vocab
# Command-line flags. "model_dir" has no default (None) and is mandatory;
# "vocab_processor_file" defaults to the path below.
tf.flags.DEFINE_string(
    "model_dir",
    None,
    "Directory to load model checkpoints from")
tf.flags.DEFINE_string(
    "vocab_processor_file",
    "./data/vocab_processor.bin",
    "Saved vocabulary processor file")
FLAGS = tf.flags.FLAGS

# Abort with exit status 1 when no (or an empty) model directory was given.
if not FLAGS.model_dir:
    print("You must specify a model directory")
    sys.exit(1)
def tokenizer_fn(iterator):
    """Lazily split every string of *iterator* on single spaces.

    Returns a generator that yields one list of tokens per input string.
    Splitting uses ``str.split(" ")``, so consecutive spaces produce empty
    tokens and an empty string produces ``[""]``.

    NOTE(review): nothing in this file calls this function directly; it is
    presumably needed by name when the pickled vocabulary processor is
    restored -- confirm before removing or renaming it.
    """
    tokenized = (line.split(" ") for line in iterator)
    return tokenized
# Restore the saved vocabulary processor from the file named by the
# --vocab_processor_file flag; it maps raw text to word-id vectors.
vp = tf.contrib.learn.preprocessing.VocabularyProcessor.restore(
    FLAGS.vocab_processor_file
)

# Sample inputs -- replace these with your own data.
# INPUT_CONTEXT is scored against every entry of POTENTIAL_RESPONSES.
INPUT_CONTEXT = "Example context"
POTENTIAL_RESPONSES = [
    "Response 1",
    "Response 2",
]
def get_features(context, utterance):
    """Build the prediction input for one (context, utterance) pair.

    Both strings are converted to word-id matrices with the module-level
    vocabulary processor ``vp``; their lengths are the number of
    space-separated tokens, capped at the width of the matching matrix.

    Args:
        context: The conversation context as a single string.
        utterance: The candidate response as a single string.

    Returns:
        A ``(features, None)`` tuple in the form expected from an
        ``input_fn``. ``features`` is a dict of int64 tensors with the keys
        ``"context"``, ``"context_len"``, ``"utterance"`` and
        ``"utterance_len"``; the lengths have shape ``[1, 1]``. The second
        element is ``None`` because there are no labels at prediction time.
    """
    # A one-element list is transformed, so each matrix has a single row.
    context_matrix = np.array(list(vp.transform([context])))
    utterance_matrix = np.array(list(vp.transform([utterance])))

    # The vocabulary processor emits fixed-width rows, so an input with more
    # tokens than that width is truncated. Cap the reported lengths at the
    # row width so they never exceed the time dimension of the tensors.
    context_len = min(len(context.split(" ")), context_matrix.shape[1])
    utterance_len = min(len(utterance.split(" ")), utterance_matrix.shape[1])

    features = {
        "context": tf.convert_to_tensor(context_matrix, dtype=tf.int64),
        "context_len": tf.constant(context_len, shape=[1, 1], dtype=tf.int64),
        "utterance": tf.convert_to_tensor(utterance_matrix, dtype=tf.int64),
        "utterance_len": tf.constant(
            utterance_len, shape=[1, 1], dtype=tf.int64),
    }
    return features, None
if __name__ == "__main__":
    # Script entry point: rebuild the dual-encoder estimator from the
    # checkpoints in --model_dir and print a score for every candidate
    # response against INPUT_CONTEXT.
    hparams = udc_hparams.create_hparams()
    model_fn = udc_model.create_model_fn(
        hparams, model_impl=dual_encoder_model)
    estimator = tf.contrib.learn.Estimator(
        model_fn=model_fn, model_dir=FLAGS.model_dir)

    # Ugly hack, seems to be a bug in Tensorflow
    # estimator.predict doesn't work without this line
    estimator._targets_info = (
        tf.contrib.learn.estimators.tensor_signature.TensorSignature(
            tf.constant(0, shape=[1, 1])))

    print("Context: {}".format(INPUT_CONTEXT))
    for response in POTENTIAL_RESPONSES:
        # Bind the current response as a default argument so the input_fn
        # does not depend on the loop variable's value at call time
        # (avoids the late-binding closure pitfall if predict is ever lazy).
        prob = estimator.predict(
            input_fn=lambda response=response: get_features(
                INPUT_CONTEXT, response))
        print("{}: {:g}".format(response, prob[0, 0]))