-
Notifications
You must be signed in to change notification settings - Fork 1
/
tweet_sampler.py
46 lines (37 loc) · 1.34 KB
/
tweet_sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import random
import numpy as np
import tensorflow as tf
# Upper bound on the number of characters generated for a single sampled tweet.
MAX_SAMPLE_LENGTH: int = 200
class TweetSampler:
    """Generate tweets one character at a time from a trained char-level model.

    The wrapped model object must expose ``out_logits``, ``chars``,
    ``max_steps`` and ``features`` (the input fed via ``feed_dict``); these
    are the only model attributes this class reads.
    """

    def __init__(self, session, model, temperature=1.0):
        """Build the softmax op used for sampling.

        Args:
            session: Session whose ``run`` evaluates the prediction op.
            model: Model exposing ``out_logits``, ``chars``, ``max_steps``
                and ``features``.
            temperature: The logits are divided by this before the softmax.
                Values below 1.0 sharpen the distribution; values above 1.0
                flatten it. Should be positive (0 divides by zero).
        """
        self.session = session
        self.model = model
        # Softmax along axis 1: one distribution per row of out_logits.
        self.predictions_flat = tf.nn.softmax(model.out_logits / temperature, 1)

    def sample(self, max_length=None):
        """Sample one tweet and return it with surrounding whitespace stripped.

        Generation starts from the start symbol (label ``len(model.chars)``).
        It stops when that same label is drawn again, or after
        ``max_length`` characters.

        Args:
            max_length: Maximum number of characters to generate. ``None``
                (the default) uses the module-level ``MAX_SAMPLE_LENGTH``.

        Returns:
            The generated text.
        """
        if max_length is None:
            max_length = MAX_SAMPLE_LENGTH
        # The label one past the last character index serves as both the
        # start symbol and the end-of-tweet symbol.
        boundary = len(self.model.chars)
        features = [boundary]
        pieces = []
        for _ in range(max_length):
            # Only the most recent max_steps labels fit in the model input.
            next_class = self.sample_next_class(features[-self.model.max_steps:])
            if next_class == boundary:
                break
            features.append(next_class)
            pieces.append(self.model.chars[next_class])
        return ''.join(pieces).strip()

    def sample_next_class(self, classes):
        """Draw the label that follows ``classes`` from the model's output.

        Args:
            classes: Non-empty sequence of at most ``model.max_steps`` labels.
                The labels are written left-aligned into a zero-padded input
                row.

        Returns:
            The sampled label as an ``int``.

        Raises:
            ValueError: If ``classes`` is empty.
        """
        # len() rather than truthiness so numpy arrays are accepted too.
        if len(classes) == 0:
            raise ValueError("classes must contain at least one label")
        sample_input = np.zeros([1, self.model.max_steps])
        sample_input[:, :len(classes)] = classes
        predictions = self.session.run(
            self.predictions_flat,
            feed_dict={self.model.features: sample_input},
        )
        # Take the row for the last real (non-padding) input position.
        # NOTE(review): assumes out_logits has one row per time step of the
        # single batch entry -- confirm against the model definition.
        probabilities = predictions[len(classes) - 1]
        # Inverse-CDF sampling with the stdlib RNG: return the first label
        # whose cumulative probability reaches the uniform draw.
        threshold = random.random()
        cumulative = 0
        for idx, prob in enumerate(probabilities):
            cumulative += prob
            if cumulative >= threshold:
                return idx
        # Float rounding can leave the total just under the threshold.
        # Fall back to the most likely label (previously this took the
        # argmax of the *input labels*, which was meaningless).
        return int(np.argmax(probabilities))