-
Notifications
You must be signed in to change notification settings - Fork 68
/
sample_rnn.py
executable file
·68 lines (55 loc) · 1.62 KB
/
sample_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import print_function
import os
import sys
import time
import importlib
if sys.version_info < (3,0):
import cPickle as pickle
else:
import pickle
from folk_rnn import Folk_RNN
import argparse
# Command-line interface. Only the metadata path is required; the remaining
# options are forwarded to the model or control where the tunes end up.
parser = argparse.ArgumentParser(
    description='Sample tunes from a trained folk-rnn model.')
parser.add_argument(
    'metadata_path',
    help='path to the pickled metadata file of the trained model')
parser.add_argument(
    '--rng_seed', type=int,
    help='integer seed handed to Folk_RNN as its rng_seed argument')
parser.add_argument(
    '--temperature', type=float,
    help='sampling temperature handed to Folk_RNN')
parser.add_argument(
    '--ntunes', type=int, default=1,
    help='number of tunes to compose (default: 1)')
parser.add_argument(
    '--seed',
    help='seed string handed to Folk_RNN.seed_tune before composing')
parser.add_argument(
    '--terminal', action="store_true",
    help='print tunes to stdout instead of appending them to a file '
         'under samples/')
args = parser.parse_args()

metadata_path = args.metadata_path
rng_seed = args.rng_seed
temperature = args.temperature
ntunes = args.ntunes
seed = args.seed
print('seed', seed)

# Pickle streams are binary: opening in text mode makes pickle.load fail on
# Python 3 (and can corrupt data on platforms that translate newlines).
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only point this script at metadata files you trust.
with open(metadata_path, 'rb') as f:
    metadata = pickle.load(f)

# The metadata records which module under `configurations` was used for
# training; import it to recover the network hyper-parameters.
config = importlib.import_module('configurations.%s' % metadata['configuration'])

# samples dir
if not os.path.isdir('samples'):
    os.makedirs('samples')

# --rng_seed and --temperature are optional, so either may be None here.
# '%d' / '%.2f' raise TypeError on None, which used to abort the script
# before any tune was composed; fall back to a literal placeholder instead.
# When both values are supplied the file name is unchanged.
rng_seed_label = 'none' if rng_seed is None else '%d' % rng_seed
temperature_label = 'none' if temperature is None else '%.2f' % temperature
target_path = "samples/%s-s%s-%s-%s.txt" % (
    metadata['experiment_id'],
    rng_seed_label,
    temperature_label,
    time.strftime("%Y%m%d-%H%M%S", time.localtime()))

# One-hot configurations are built without an embedding size.
if config.one_hot:
    config.embedding_size = None

folk_rnn = Folk_RNN(
    metadata['token2idx'],
    metadata['param_values'],
    config.num_layers,
    config.rnn_size,
    config.grad_clipping,
    config.dropout,
    config.embedding_size,
    rng_seed,
    temperature
)
folk_rnn.seed_tune(seed)

# `range` instead of `xrange`: the latter does not exist on Python 3, which
# the pickle import shim at the top of the file explicitly supports.
for i in range(ntunes):
    # Each tune is prefixed with an "X:<index>" header line.
    tune = 'X:{}\n{}\n'.format(i, folk_rnn.compose_tune())
    if args.terminal:
        print(tune)
    else:
        # Append per tune so already-composed tunes survive an interruption.
        with open(target_path, 'a+') as f:
            f.write(tune)
        print('Saved to {}'.format(target_path))